Compare commits
3 Commits
fa966d30db
...
0a27b9a8d3
Author | SHA1 | Date | |
---|---|---|---|
0a27b9a8d3 | |||
2611663168 | |||
120e61e88b |
10
go.mod
10
go.mod
@ -4,8 +4,8 @@ go 1.21
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/alecthomas/chroma/v2 v2.11.1
|
github.com/alecthomas/chroma/v2 v2.11.1
|
||||||
|
github.com/charmbracelet/lipgloss v0.10.0
|
||||||
github.com/go-yaml/yaml v2.1.0+incompatible
|
github.com/go-yaml/yaml v2.1.0+incompatible
|
||||||
github.com/gookit/color v1.5.4
|
|
||||||
github.com/sashabaranov/go-openai v1.17.7
|
github.com/sashabaranov/go-openai v1.17.7
|
||||||
github.com/spf13/cobra v1.8.0
|
github.com/spf13/cobra v1.8.0
|
||||||
github.com/sqids/sqids-go v0.4.1
|
github.com/sqids/sqids-go v0.4.1
|
||||||
@ -14,14 +14,20 @@ require (
|
|||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||||
github.com/dlclark/regexp2 v1.10.0 // indirect
|
github.com/dlclark/regexp2 v1.10.0 // indirect
|
||||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
github.com/jinzhu/now v1.1.5 // indirect
|
github.com/jinzhu/now v1.1.5 // indirect
|
||||||
github.com/kr/pretty v0.3.1 // indirect
|
github.com/kr/pretty v0.3.1 // indirect
|
||||||
|
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
|
||||||
|
github.com/mattn/go-isatty v0.0.18 // indirect
|
||||||
|
github.com/mattn/go-runewidth v0.0.15 // indirect
|
||||||
github.com/mattn/go-sqlite3 v1.14.18 // indirect
|
github.com/mattn/go-sqlite3 v1.14.18 // indirect
|
||||||
|
github.com/muesli/reflow v0.3.0 // indirect
|
||||||
|
github.com/muesli/termenv v0.15.2 // indirect
|
||||||
|
github.com/rivo/uniseg v0.4.7 // indirect
|
||||||
github.com/spf13/pflag v1.0.5 // indirect
|
github.com/spf13/pflag v1.0.5 // indirect
|
||||||
github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 // indirect
|
|
||||||
golang.org/x/sys v0.14.0 // indirect
|
golang.org/x/sys v0.14.0 // indirect
|
||||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
|
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
|
||||||
gopkg.in/yaml.v2 v2.2.2 // indirect
|
gopkg.in/yaml.v2 v2.2.2 // indirect
|
||||||
|
31
go.sum
31
go.sum
@ -4,16 +4,16 @@ github.com/alecthomas/chroma/v2 v2.11.1 h1:m9uUtgcdAwgfFNxuqj7AIG75jD2YmL61BBIJW
|
|||||||
github.com/alecthomas/chroma/v2 v2.11.1/go.mod h1:4TQu7gdfuPjSh76j78ietmqh9LiurGF0EpseFXdKMBw=
|
github.com/alecthomas/chroma/v2 v2.11.1/go.mod h1:4TQu7gdfuPjSh76j78ietmqh9LiurGF0EpseFXdKMBw=
|
||||||
github.com/alecthomas/repr v0.2.0 h1:HAzS41CIzNW5syS8Mf9UwXhNH1J9aix/BvDRf1Ml2Yk=
|
github.com/alecthomas/repr v0.2.0 h1:HAzS41CIzNW5syS8Mf9UwXhNH1J9aix/BvDRf1Ml2Yk=
|
||||||
github.com/alecthomas/repr v0.2.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
|
github.com/alecthomas/repr v0.2.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
|
||||||
|
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||||
|
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||||
|
github.com/charmbracelet/lipgloss v0.10.0 h1:KWeXFSexGcfahHX+54URiZGkBFazf70JNMtwg/AFW3s=
|
||||||
|
github.com/charmbracelet/lipgloss v0.10.0/go.mod h1:Wig9DSfvANsxqkRsqj6x87irdy123SR4dOXlKa91ciE=
|
||||||
github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
|
||||||
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
|
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
|
||||||
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||||
github.com/go-yaml/yaml v2.1.0+incompatible h1:RYi2hDdss1u4YE7GwixGzWwVo47T8UQwnTLB6vQiq+o=
|
github.com/go-yaml/yaml v2.1.0+incompatible h1:RYi2hDdss1u4YE7GwixGzWwVo47T8UQwnTLB6vQiq+o=
|
||||||
github.com/go-yaml/yaml v2.1.0+incompatible/go.mod h1:w2MrLa16VYP0jy6N7M5kHaCkaLENm+P+Tv+MfurjSw0=
|
github.com/go-yaml/yaml v2.1.0+incompatible/go.mod h1:w2MrLa16VYP0jy6N7M5kHaCkaLENm+P+Tv+MfurjSw0=
|
||||||
github.com/gookit/color v1.5.4 h1:FZmqs7XOyGgCAxmWyPslpiok1k05wmY3SJTytgvYFs0=
|
|
||||||
github.com/gookit/color v1.5.4/go.mod h1:pZJOeOS8DM43rXbp4AZo1n9zCU2qjpcRko0b6/QJi9w=
|
|
||||||
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
|
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
|
||||||
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
|
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
|
||||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||||
@ -26,11 +26,24 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
|||||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||||
|
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
|
||||||
|
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||||
|
github.com/mattn/go-isatty v0.0.18 h1:DOKFKCQ7FNG2L1rbrmstDN4QVRdS89Nkh85u68Uwp98=
|
||||||
|
github.com/mattn/go-isatty v0.0.18/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
|
||||||
|
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
|
||||||
|
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||||
github.com/mattn/go-sqlite3 v1.14.18 h1:JL0eqdCOq6DJVNPSvArO/bIV9/P7fbGrV00LZHc+5aI=
|
github.com/mattn/go-sqlite3 v1.14.18 h1:JL0eqdCOq6DJVNPSvArO/bIV9/P7fbGrV00LZHc+5aI=
|
||||||
github.com/mattn/go-sqlite3 v1.14.18/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
github.com/mattn/go-sqlite3 v1.14.18/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||||
|
github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s=
|
||||||
|
github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8=
|
||||||
|
github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo=
|
||||||
|
github.com/muesli/termenv v0.15.2/go.mod h1:Epx+iuz8sNs7mNKhxzH4fWXGNpZwUaJKRS1noLXviQ8=
|
||||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||||
|
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||||
|
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||||
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
|
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
|
||||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||||
@ -42,10 +55,7 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
|||||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||||
github.com/sqids/sqids-go v0.4.1 h1:eQKYzmAZbLlRwHeHYPF35QhgxwZHLnlmVj9AkIj/rrw=
|
github.com/sqids/sqids-go v0.4.1 h1:eQKYzmAZbLlRwHeHYPF35QhgxwZHLnlmVj9AkIj/rrw=
|
||||||
github.com/sqids/sqids-go v0.4.1/go.mod h1:EMwHuPQgSNFS0A49jESTfIQS+066XQTVhukrzEPScl8=
|
github.com/sqids/sqids-go v0.4.1/go.mod h1:EMwHuPQgSNFS0A49jESTfIQS+066XQTVhukrzEPScl8=
|
||||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
|
||||||
github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 h1:QldyIu/L63oPpyvQmHgvgickp1Yw510KJOqX7H24mg8=
|
|
||||||
github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778/go.mod h1:2MuV+tbUrU1zIOPMxZ5EncGwgmMJsa+9ucAQZXxsObs=
|
|
||||||
golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q=
|
golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q=
|
||||||
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
@ -53,7 +63,6 @@ gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogR
|
|||||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
||||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
|
||||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
gorm.io/driver/sqlite v1.5.4 h1:IqXwXi8M/ZlPzH/947tn5uik3aYQslP9BVveoax0nV0=
|
gorm.io/driver/sqlite v1.5.4 h1:IqXwXi8M/ZlPzH/947tn5uik3aYQslP9BVveoax0nV0=
|
||||||
gorm.io/driver/sqlite v1.5.4/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4=
|
gorm.io/driver/sqlite v1.5.4/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4=
|
||||||
|
17
main.go
17
main.go
@ -1,15 +1,18 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
"os"
|
"git.mlow.ca/mlow/lmcli/pkg/cmd"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/cli"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
if err := cli.Execute(); err != nil {
|
ctx, err := lmcli.NewContext()
|
||||||
fmt.Fprintln(os.Stderr, err.Error())
|
if err != nil {
|
||||||
os.Exit(1)
|
lmcli.Fatal("%v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
root := cmd.RootCmd(ctx)
|
||||||
|
if err := root.Execute(); err != nil {
|
||||||
|
lmcli.Fatal("%v\n", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,32 +0,0 @@
|
|||||||
package cli
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
)
|
|
||||||
|
|
||||||
var config *Config
|
|
||||||
var store *Store
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
var err error
|
|
||||||
|
|
||||||
config, err = NewConfig()
|
|
||||||
if err != nil {
|
|
||||||
Fatal("%v\n", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
store, err = NewStore()
|
|
||||||
if err != nil {
|
|
||||||
Fatal("%v\n", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Fatal(format string, args ...any) {
|
|
||||||
fmt.Fprintf(os.Stderr, format, args...)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Warn(format string, args ...any) {
|
|
||||||
fmt.Fprintf(os.Stderr, format, args...)
|
|
||||||
}
|
|
719
pkg/cli/cmd.go
719
pkg/cli/cmd.go
@ -1,719 +0,0 @@
|
|||||||
package cli
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"slices"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
maxTokens int
|
|
||||||
model string
|
|
||||||
systemPrompt string
|
|
||||||
systemPromptFile string
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// Limit number of conversations shown with `ls`, without --all
|
|
||||||
LS_LIMIT int = 25
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd, retryCmd, continueCmd, editCmd}
|
|
||||||
for _, cmd := range inputCmds {
|
|
||||||
cmd.Flags().IntVar(&maxTokens, "length", *config.OpenAI.DefaultMaxLength, "Maximum response tokens")
|
|
||||||
cmd.Flags().StringVar(&model, "model", *config.OpenAI.DefaultModel, "Which model to use model")
|
|
||||||
cmd.Flags().StringVar(&systemPrompt, "system-prompt", *config.ModelDefaults.SystemPrompt, "System prompt")
|
|
||||||
cmd.Flags().StringVar(&systemPromptFile, "system-prompt-file", "", "A path to a file containing the system prompt")
|
|
||||||
cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file")
|
|
||||||
}
|
|
||||||
|
|
||||||
listCmd.Flags().Bool("all", false, fmt.Sprintf("Show all conversations, by default only the last %d are shown", LS_LIMIT))
|
|
||||||
renameCmd.Flags().Bool("generate", false, "Generate a conversation title")
|
|
||||||
editCmd.Flags().Int("offset", 1, "Offset from the last reply to edit (Default: edit your last message, assuming there's an assistant reply)")
|
|
||||||
|
|
||||||
rootCmd.AddCommand(
|
|
||||||
cloneCmd,
|
|
||||||
continueCmd,
|
|
||||||
editCmd,
|
|
||||||
listCmd,
|
|
||||||
newCmd,
|
|
||||||
promptCmd,
|
|
||||||
renameCmd,
|
|
||||||
replyCmd,
|
|
||||||
retryCmd,
|
|
||||||
rmCmd,
|
|
||||||
viewCmd,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Execute() error {
|
|
||||||
return rootCmd.Execute()
|
|
||||||
}
|
|
||||||
|
|
||||||
func getSystemPrompt() string {
|
|
||||||
if systemPromptFile != "" {
|
|
||||||
content, err := FileContents(systemPromptFile)
|
|
||||||
if err != nil {
|
|
||||||
Fatal("Could not read file contents at %s: %v\n", systemPromptFile, err)
|
|
||||||
}
|
|
||||||
return content
|
|
||||||
}
|
|
||||||
return systemPrompt
|
|
||||||
}
|
|
||||||
|
|
||||||
// fetchAndShowCompletion prompts the LLM with the given messages and streams
|
|
||||||
// the response to stdout. Returns all model reply messages.
|
|
||||||
func fetchAndShowCompletion(messages []Message) ([]Message, error) {
|
|
||||||
content := make(chan string) // receives the reponse from LLM
|
|
||||||
defer close(content)
|
|
||||||
|
|
||||||
// render all content received over the channel
|
|
||||||
go ShowDelayedContent(content)
|
|
||||||
|
|
||||||
var replies []Message
|
|
||||||
response, err := CreateChatCompletionStream(model, messages, maxTokens, content, &replies)
|
|
||||||
if response != "" {
|
|
||||||
// there was some content, so break to a new line after it
|
|
||||||
fmt.Println()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
Warn("Received partial response. Error: %v\n", err)
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return replies, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// lookupConversation either returns the conversation found by the
|
|
||||||
// short name or exits the program
|
|
||||||
func lookupConversation(shortName string) *Conversation {
|
|
||||||
c, err := store.ConversationByShortName(shortName)
|
|
||||||
if err != nil {
|
|
||||||
Fatal("Could not lookup conversation: %v\n", err)
|
|
||||||
}
|
|
||||||
if c.ID == 0 {
|
|
||||||
Fatal("Conversation not found with short name: %s\n", shortName)
|
|
||||||
}
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
func lookupConversationE(shortName string) (*Conversation, error) {
|
|
||||||
c, err := store.ConversationByShortName(shortName)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("Could not lookup conversation: %v", err)
|
|
||||||
}
|
|
||||||
if c.ID == 0 {
|
|
||||||
return nil, fmt.Errorf("Conversation not found with short name: %s", shortName)
|
|
||||||
}
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleConversationReply handles sending messages to an existing
|
|
||||||
// conversation, optionally persisting both the sent replies and responses.
|
|
||||||
func handleConversationReply(c *Conversation, persist bool, toSend ...Message) {
|
|
||||||
existing, err := store.Messages(c)
|
|
||||||
if err != nil {
|
|
||||||
Fatal("Could not retrieve messages for conversation: %s\n", c.Title)
|
|
||||||
}
|
|
||||||
|
|
||||||
if persist {
|
|
||||||
for _, message := range toSend {
|
|
||||||
err = store.SaveMessage(&message)
|
|
||||||
if err != nil {
|
|
||||||
Warn("Could not save %s message: %v\n", message.Role, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
allMessages := append(existing, toSend...)
|
|
||||||
|
|
||||||
RenderConversation(allMessages, true)
|
|
||||||
|
|
||||||
// render a message header with no contents
|
|
||||||
(&Message{Role: MessageRoleAssistant}).RenderTTY()
|
|
||||||
|
|
||||||
replies, err := fetchAndShowCompletion(allMessages)
|
|
||||||
if err != nil {
|
|
||||||
Fatal("Error fetching LLM response: %v\n", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if persist {
|
|
||||||
for _, reply := range replies {
|
|
||||||
reply.ConversationID = c.ID
|
|
||||||
|
|
||||||
err = store.SaveMessage(&reply)
|
|
||||||
if err != nil {
|
|
||||||
Warn("Could not save reply: %v\n", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// inputFromArgsOrEditor returns either the provided input from the args slice
|
|
||||||
// (joined with spaces), or if len(args) is 0, opens an editor and returns
|
|
||||||
// whatever input was provided there. placeholder is a string which populates
|
|
||||||
// the editor and gets stripped from the final output.
|
|
||||||
func inputFromArgsOrEditor(args []string, placeholder string, existingMessage string) (message string) {
|
|
||||||
var err error
|
|
||||||
if len(args) == 0 {
|
|
||||||
message, err = InputFromEditor(placeholder, "message.*.md", existingMessage)
|
|
||||||
if err != nil {
|
|
||||||
Fatal("Failed to get input: %v\n", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
message = strings.Trim(strings.Join(args, " "), " \t\n")
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var rootCmd = &cobra.Command{
|
|
||||||
Use: "lmcli <command> [flags]",
|
|
||||||
Long: `lmcli - Large Language Model CLI`,
|
|
||||||
SilenceErrors: true,
|
|
||||||
SilenceUsage: true,
|
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
|
||||||
cmd.Usage()
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var listCmd = &cobra.Command{
|
|
||||||
Use: "list",
|
|
||||||
Aliases: []string{"ls"},
|
|
||||||
Short: "List conversations",
|
|
||||||
Long: `List conversations in order of recent activity`,
|
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
|
||||||
conversations, err := store.Conversations()
|
|
||||||
if err != nil {
|
|
||||||
Fatal("Could not fetch conversations.\n")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
type Category struct {
|
|
||||||
name string
|
|
||||||
cutoff time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
type ConversationLine struct {
|
|
||||||
timeSinceReply time.Duration
|
|
||||||
formatted string
|
|
||||||
}
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
|
|
||||||
midnight := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
|
||||||
monthStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())
|
|
||||||
dayOfWeek := int(now.Weekday())
|
|
||||||
categories := []Category{
|
|
||||||
{"today", now.Sub(midnight)},
|
|
||||||
{"yesterday", now.Sub(midnight.AddDate(0, 0, -1))},
|
|
||||||
{"this week", now.Sub(midnight.AddDate(0, 0, -dayOfWeek))},
|
|
||||||
{"last week", now.Sub(midnight.AddDate(0, 0, -(dayOfWeek + 7)))},
|
|
||||||
{"this month", now.Sub(monthStart)},
|
|
||||||
{"last month", now.Sub(monthStart.AddDate(0, -1, 0))},
|
|
||||||
{"2 months ago", now.Sub(monthStart.AddDate(0, -2, 0))},
|
|
||||||
{"3 months ago", now.Sub(monthStart.AddDate(0, -3, 0))},
|
|
||||||
{"4 months ago", now.Sub(monthStart.AddDate(0, -4, 0))},
|
|
||||||
{"5 months ago", now.Sub(monthStart.AddDate(0, -5, 0))},
|
|
||||||
{"older", now.Sub(time.Time{})},
|
|
||||||
}
|
|
||||||
categorized := map[string][]ConversationLine{}
|
|
||||||
|
|
||||||
all, _ := cmd.Flags().GetBool("all")
|
|
||||||
|
|
||||||
for _, conversation := range conversations {
|
|
||||||
lastMessage, err := store.LastMessage(&conversation)
|
|
||||||
if lastMessage == nil || err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
messageAge := now.Sub(lastMessage.CreatedAt)
|
|
||||||
|
|
||||||
var category string
|
|
||||||
for _, c := range categories {
|
|
||||||
if messageAge < c.cutoff {
|
|
||||||
category = c.name
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
formatted := fmt.Sprintf(
|
|
||||||
"%s - %s - %s",
|
|
||||||
conversation.ShortName.String,
|
|
||||||
humanTimeElapsedSince(messageAge),
|
|
||||||
conversation.Title,
|
|
||||||
)
|
|
||||||
|
|
||||||
categorized[category] = append(
|
|
||||||
categorized[category],
|
|
||||||
ConversationLine{messageAge, formatted},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
var conversationsPrinted int
|
|
||||||
outer:
|
|
||||||
for _, category := range categories {
|
|
||||||
conversations, ok := categorized[category.name]
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
slices.SortFunc(conversations, func(a, b ConversationLine) int {
|
|
||||||
return int(a.timeSinceReply - b.timeSinceReply)
|
|
||||||
})
|
|
||||||
|
|
||||||
fmt.Printf("%s:\n", category.name)
|
|
||||||
for _, conv := range conversations {
|
|
||||||
if conversationsPrinted >= LS_LIMIT && !all {
|
|
||||||
fmt.Printf("%d remaining message(s), use --all to view.\n", len(conversations)-conversationsPrinted)
|
|
||||||
break outer
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf(" %s\n", conv.formatted)
|
|
||||||
conversationsPrinted++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var rmCmd = &cobra.Command{
|
|
||||||
Use: "rm <conversation>...",
|
|
||||||
Short: "Remove conversations",
|
|
||||||
Long: `Remove conversations by their short names.`,
|
|
||||||
Args: func(cmd *cobra.Command, args []string) error {
|
|
||||||
argCount := 1
|
|
||||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
|
||||||
var toRemove []*Conversation
|
|
||||||
for _, shortName := range args {
|
|
||||||
conversation := lookupConversation(shortName)
|
|
||||||
toRemove = append(toRemove, conversation)
|
|
||||||
}
|
|
||||||
var errors []error
|
|
||||||
for _, c := range toRemove {
|
|
||||||
err := store.DeleteConversation(c)
|
|
||||||
if err != nil {
|
|
||||||
errors = append(errors, fmt.Errorf("Could not remove conversation %s: %v", c.ShortName.String, err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, err := range errors {
|
|
||||||
fmt.Fprintln(os.Stderr, err.Error())
|
|
||||||
}
|
|
||||||
if len(errors) > 0 {
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
|
||||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
|
||||||
var completions []string
|
|
||||||
outer:
|
|
||||||
for _, completion := range store.ConversationShortNameCompletions(toComplete) {
|
|
||||||
parts := strings.Split(completion, "\t")
|
|
||||||
for _, arg := range args {
|
|
||||||
if parts[0] == arg {
|
|
||||||
continue outer
|
|
||||||
}
|
|
||||||
}
|
|
||||||
completions = append(completions, completion)
|
|
||||||
}
|
|
||||||
return completions, compMode
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var cloneCmd = &cobra.Command{
|
|
||||||
Use: "clone <conversation>",
|
|
||||||
Short: "Clone conversations",
|
|
||||||
Long: `Clones the provided conversation.`,
|
|
||||||
Args: func(cmd *cobra.Command, args []string) error {
|
|
||||||
argCount := 1
|
|
||||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
|
||||||
shortName := args[0]
|
|
||||||
toClone, err := lookupConversationE(shortName)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
messagesToCopy, err := store.Messages(toClone)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Could not retrieve messages for conversation: %s", toClone.ShortName.String)
|
|
||||||
}
|
|
||||||
|
|
||||||
clone := &Conversation{
|
|
||||||
Title: toClone.Title + " - Clone",
|
|
||||||
}
|
|
||||||
if err := store.SaveConversation(clone); err != nil {
|
|
||||||
return fmt.Errorf("Cloud not create clone: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var errors []error
|
|
||||||
messageCnt := 0
|
|
||||||
for _, message := range messagesToCopy {
|
|
||||||
newMessage := message
|
|
||||||
newMessage.ConversationID = clone.ID
|
|
||||||
newMessage.ID = 0
|
|
||||||
if err := store.SaveMessage(&newMessage); err != nil {
|
|
||||||
errors = append(errors, err)
|
|
||||||
} else {
|
|
||||||
messageCnt++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(errors) > 0 {
|
|
||||||
return fmt.Errorf("Messages failed to be cloned: %v", errors)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("Cloned %d messages to: %s\n", messageCnt, clone.Title)
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
|
||||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
|
||||||
if len(args) != 0 {
|
|
||||||
return nil, compMode
|
|
||||||
}
|
|
||||||
return store.ConversationShortNameCompletions(toComplete), compMode
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var viewCmd = &cobra.Command{
|
|
||||||
Use: "view <conversation>",
|
|
||||||
Short: "View messages in a conversation",
|
|
||||||
Long: `Finds a conversation by its short name and displays its contents.`,
|
|
||||||
Args: func(cmd *cobra.Command, args []string) error {
|
|
||||||
argCount := 1
|
|
||||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
|
||||||
shortName := args[0]
|
|
||||||
conversation := lookupConversation(shortName)
|
|
||||||
|
|
||||||
messages, err := store.Messages(conversation)
|
|
||||||
if err != nil {
|
|
||||||
Fatal("Could not retrieve messages for conversation: %s\n", conversation.Title)
|
|
||||||
}
|
|
||||||
|
|
||||||
RenderConversation(messages, false)
|
|
||||||
},
|
|
||||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
|
||||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
|
||||||
if len(args) != 0 {
|
|
||||||
return nil, compMode
|
|
||||||
}
|
|
||||||
return store.ConversationShortNameCompletions(toComplete), compMode
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var renameCmd = &cobra.Command{
|
|
||||||
Use: "rename <conversation> [title]",
|
|
||||||
Short: "Rename a conversation",
|
|
||||||
Long: `Renames a conversation, either with the provided title or by generating a new name.`,
|
|
||||||
Args: func(cmd *cobra.Command, args []string) error {
|
|
||||||
argCount := 1
|
|
||||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
|
||||||
shortName := args[0]
|
|
||||||
conversation := lookupConversation(shortName)
|
|
||||||
var err error
|
|
||||||
|
|
||||||
generate, _ := cmd.Flags().GetBool("generate")
|
|
||||||
var title string
|
|
||||||
if generate {
|
|
||||||
title, err = conversation.GenerateTitle()
|
|
||||||
if err != nil {
|
|
||||||
Fatal("Could not generate conversation title: %v\n", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if len(args) < 2 {
|
|
||||||
Fatal("Conversation title not provided.\n")
|
|
||||||
}
|
|
||||||
title = strings.Join(args[1:], " ")
|
|
||||||
}
|
|
||||||
|
|
||||||
conversation.Title = title
|
|
||||||
err = store.SaveConversation(conversation)
|
|
||||||
if err != nil {
|
|
||||||
Warn("Could not save conversation with new title: %v\n", err)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
|
||||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
|
||||||
if len(args) != 0 {
|
|
||||||
return nil, compMode
|
|
||||||
}
|
|
||||||
return store.ConversationShortNameCompletions(toComplete), compMode
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var replyCmd = &cobra.Command{
|
|
||||||
Use: "reply <conversation> [message]",
|
|
||||||
Short: "Reply to a conversation",
|
|
||||||
Long: `Sends a reply to conversation and writes the response to stdout.`,
|
|
||||||
Args: func(cmd *cobra.Command, args []string) error {
|
|
||||||
argCount := 1
|
|
||||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
|
||||||
shortName := args[0]
|
|
||||||
conversation := lookupConversation(shortName)
|
|
||||||
|
|
||||||
reply := inputFromArgsOrEditor(args[1:], "# How would you like to reply?\n", "")
|
|
||||||
if reply == "" {
|
|
||||||
Fatal("No reply was provided.\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
handleConversationReply(conversation, true, Message{
|
|
||||||
ConversationID: conversation.ID,
|
|
||||||
Role: MessageRoleUser,
|
|
||||||
OriginalContent: reply,
|
|
||||||
})
|
|
||||||
},
|
|
||||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
|
||||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
|
||||||
if len(args) != 0 {
|
|
||||||
return nil, compMode
|
|
||||||
}
|
|
||||||
return store.ConversationShortNameCompletions(toComplete), compMode
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var newCmd = &cobra.Command{
|
|
||||||
Use: "new [message]",
|
|
||||||
Short: "Start a new conversation",
|
|
||||||
Long: `Start a new conversation with the Large Language Model.`,
|
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
|
||||||
messageContents := inputFromArgsOrEditor(args, "# What would you like to say?\n", "")
|
|
||||||
if messageContents == "" {
|
|
||||||
Fatal("No message was provided.\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
conversation := &Conversation{}
|
|
||||||
err := store.SaveConversation(conversation)
|
|
||||||
if err != nil {
|
|
||||||
Fatal("Could not save new conversation: %v\n", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
messages := []Message{
|
|
||||||
{
|
|
||||||
ConversationID: conversation.ID,
|
|
||||||
Role: MessageRoleSystem,
|
|
||||||
OriginalContent: getSystemPrompt(),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ConversationID: conversation.ID,
|
|
||||||
Role: MessageRoleUser,
|
|
||||||
OriginalContent: messageContents,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
handleConversationReply(conversation, true, messages...)
|
|
||||||
|
|
||||||
title, err := conversation.GenerateTitle()
|
|
||||||
if err != nil {
|
|
||||||
Warn("Could not generate title for conversation: %v\n", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
conversation.Title = title
|
|
||||||
|
|
||||||
err = store.SaveConversation(conversation)
|
|
||||||
if err != nil {
|
|
||||||
Warn("Could not save conversation after generating title: %v\n", err)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var promptCmd = &cobra.Command{
|
|
||||||
Use: "prompt [message]",
|
|
||||||
Short: "Do a one-shot prompt",
|
|
||||||
Long: `Prompt the Large Language Model and get a response.`,
|
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
|
||||||
message := inputFromArgsOrEditor(args, "# What would you like to say?\n", "")
|
|
||||||
if message == "" {
|
|
||||||
Fatal("No message was provided.\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
messages := []Message{
|
|
||||||
{
|
|
||||||
Role: MessageRoleSystem,
|
|
||||||
OriginalContent: getSystemPrompt(),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Role: MessageRoleUser,
|
|
||||||
OriginalContent: message,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := fetchAndShowCompletion(messages)
|
|
||||||
if err != nil {
|
|
||||||
Fatal("Error fetching LLM response: %v\n", err)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var retryCmd = &cobra.Command{
|
|
||||||
Use: "retry <conversation>",
|
|
||||||
Short: "Retry the last user reply in a conversation",
|
|
||||||
Long: `Re-prompt the conversation up to the last user response. Can be used to regenerate the last assistant reply, or simply generate one if an error occurred.`,
|
|
||||||
Args: func(cmd *cobra.Command, args []string) error {
|
|
||||||
argCount := 1
|
|
||||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
|
||||||
shortName := args[0]
|
|
||||||
conversation := lookupConversation(shortName)
|
|
||||||
|
|
||||||
messages, err := store.Messages(conversation)
|
|
||||||
if err != nil {
|
|
||||||
Fatal("Could not retrieve messages for conversation: %s\n", conversation.Title)
|
|
||||||
}
|
|
||||||
|
|
||||||
// walk backwards through the conversation and delete messages, break
|
|
||||||
// when we find the latest user response
|
|
||||||
for i := len(messages) - 1; i >= 0; i-- {
|
|
||||||
if messages[i].Role == MessageRoleUser {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
err = store.DeleteMessage(&messages[i])
|
|
||||||
if err != nil {
|
|
||||||
Warn("Could not delete previous reply: %v\n", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
handleConversationReply(conversation, true)
|
|
||||||
},
|
|
||||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
|
||||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
|
||||||
if len(args) != 0 {
|
|
||||||
return nil, compMode
|
|
||||||
}
|
|
||||||
return store.ConversationShortNameCompletions(toComplete), compMode
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var continueCmd = &cobra.Command{
|
|
||||||
Use: "continue <conversation>",
|
|
||||||
Short: "Continue a conversation from the last message",
|
|
||||||
Long: `Re-prompt the conversation with all existing prompts. Useful if a reply was cut short.`,
|
|
||||||
Args: func(cmd *cobra.Command, args []string) error {
|
|
||||||
argCount := 1
|
|
||||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
|
||||||
shortName := args[0]
|
|
||||||
conversation := lookupConversation(shortName)
|
|
||||||
handleConversationReply(conversation, true)
|
|
||||||
},
|
|
||||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
|
||||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
|
||||||
if len(args) != 0 {
|
|
||||||
return nil, compMode
|
|
||||||
}
|
|
||||||
return store.ConversationShortNameCompletions(toComplete), compMode
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var editCmd = &cobra.Command{
|
|
||||||
Use: "edit <conversation>",
|
|
||||||
Short: "Edit the last user reply in a conversation",
|
|
||||||
Args: func(cmd *cobra.Command, args []string) error {
|
|
||||||
argCount := 1
|
|
||||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
|
||||||
shortName := args[0]
|
|
||||||
conversation := lookupConversation(shortName)
|
|
||||||
|
|
||||||
messages, err := store.Messages(conversation)
|
|
||||||
if err != nil {
|
|
||||||
Fatal("Could not retrieve messages for conversation: %s\n", conversation.Title)
|
|
||||||
}
|
|
||||||
|
|
||||||
offset, _ := cmd.Flags().GetInt("offset")
|
|
||||||
if offset < 0 {
|
|
||||||
offset = -offset
|
|
||||||
}
|
|
||||||
|
|
||||||
if offset > len(messages) - 1 {
|
|
||||||
Fatal("Offset %d is before the start of the conversation\n", offset)
|
|
||||||
}
|
|
||||||
|
|
||||||
desiredIdx := len(messages) - 1 - offset
|
|
||||||
|
|
||||||
// walk backwards through the conversation deleting messages until and
|
|
||||||
// including the last user message
|
|
||||||
toRemove := []Message{}
|
|
||||||
var toEdit *Message
|
|
||||||
for i := len(messages) - 1; i >= 0; i-- {
|
|
||||||
if i == desiredIdx {
|
|
||||||
toEdit = &messages[i]
|
|
||||||
}
|
|
||||||
toRemove = append(toRemove, messages[i])
|
|
||||||
messages = messages[:i]
|
|
||||||
if toEdit != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
existingContents := toEdit.OriginalContent
|
|
||||||
|
|
||||||
newContents := inputFromArgsOrEditor(args[1:], "# Save when finished editing\n", existingContents)
|
|
||||||
switch newContents {
|
|
||||||
case existingContents:
|
|
||||||
Fatal("No edits were made.\n")
|
|
||||||
case "":
|
|
||||||
Fatal("No message was provided.\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, message := range toRemove {
|
|
||||||
err = store.DeleteMessage(&message)
|
|
||||||
if err != nil {
|
|
||||||
Warn("Could not delete message: %v\n", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
handleConversationReply(conversation, true, Message{
|
|
||||||
ConversationID: conversation.ID,
|
|
||||||
Role: MessageRoleUser,
|
|
||||||
OriginalContent: newContents,
|
|
||||||
})
|
|
||||||
},
|
|
||||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
|
||||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
|
||||||
if len(args) != 0 {
|
|
||||||
return nil, compMode
|
|
||||||
}
|
|
||||||
return store.ConversationShortNameCompletions(toComplete), compMode
|
|
||||||
},
|
|
||||||
}
|
|
@ -1,67 +0,0 @@
|
|||||||
package cli
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
type MessageRole string
|
|
||||||
|
|
||||||
const (
|
|
||||||
MessageRoleUser MessageRole = "user"
|
|
||||||
MessageRoleAssistant MessageRole = "assistant"
|
|
||||||
MessageRoleSystem MessageRole = "system"
|
|
||||||
)
|
|
||||||
|
|
||||||
// FriendlyRole returns a human friendly signifier for the message's role.
|
|
||||||
func (m *Message) FriendlyRole() string {
|
|
||||||
var friendlyRole string
|
|
||||||
switch m.Role {
|
|
||||||
case MessageRoleUser:
|
|
||||||
friendlyRole = "You"
|
|
||||||
case MessageRoleSystem:
|
|
||||||
friendlyRole = "System"
|
|
||||||
case MessageRoleAssistant:
|
|
||||||
friendlyRole = "Assistant"
|
|
||||||
default:
|
|
||||||
friendlyRole = string(m.Role)
|
|
||||||
}
|
|
||||||
return friendlyRole
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conversation) GenerateTitle() (string, error) {
|
|
||||||
messages, err := store.Messages(c)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
const header = "Generate a concise 4-5 word title for the conversation below."
|
|
||||||
prompt := fmt.Sprintf("%s\n\n---\n\n%s", header, formatForExternalPrompting(messages, false))
|
|
||||||
|
|
||||||
generateRequest := []Message{
|
|
||||||
{
|
|
||||||
Role: MessageRoleUser,
|
|
||||||
OriginalContent: prompt,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
model := "gpt-3.5-turbo" // use cheap model to generate title
|
|
||||||
response, err := CreateChatCompletion(model, generateRequest, 25, nil)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return response, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func formatForExternalPrompting(messages []Message, system bool) string {
|
|
||||||
sb := strings.Builder{}
|
|
||||||
for _, message := range messages {
|
|
||||||
if message.Role == MessageRoleSystem && !system {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
sb.WriteString(fmt.Sprintf("<%s>\n", message.FriendlyRole()))
|
|
||||||
sb.WriteString(fmt.Sprintf("\"\"\"\n%s\n\"\"\"\n\n", message.OriginalContent))
|
|
||||||
}
|
|
||||||
return sb.String()
|
|
||||||
}
|
|
@ -1,582 +0,0 @@
|
|||||||
package cli
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
openai "github.com/sashabaranov/go-openai"
|
|
||||||
)
|
|
||||||
|
|
||||||
type FunctionResult struct {
|
|
||||||
Message string `json:"message"`
|
|
||||||
Result any `json:"result,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type FunctionParameter struct {
|
|
||||||
Type string `json:"type"` // "string", "integer", "boolean"
|
|
||||||
Description string `json:"description"`
|
|
||||||
Enum []string `json:"enum,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type FunctionParameters struct {
|
|
||||||
Type string `json:"type"` // "object"
|
|
||||||
Properties map[string]FunctionParameter `json:"properties"`
|
|
||||||
Required []string `json:"required,omitempty"` // required function parameter names
|
|
||||||
}
|
|
||||||
|
|
||||||
type AvailableTool struct {
|
|
||||||
openai.Tool
|
|
||||||
// The tool's implementation. Returns a string, as tool call results
|
|
||||||
// are treated as normal messages with string contents.
|
|
||||||
Impl func(arguments map[string]interface{}) (string, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
READ_DIR_DESCRIPTION = `Return the contents of the CWD (current working directory).
|
|
||||||
|
|
||||||
Results are returned as JSON in the following format:
|
|
||||||
{
|
|
||||||
"message": "success", // if successful, or a different message indicating failure
|
|
||||||
// result may be an empty array if there are no files in the directory
|
|
||||||
"result": [
|
|
||||||
{"name": "a_file", "type": "file", "size": 123},
|
|
||||||
{"name": "a_directory/", "type": "dir", "size": 11},
|
|
||||||
... // more files or directories
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
For files, size represents the size (in bytes) of the file.
|
|
||||||
For directories, size represents the number of entries in that directory.`
|
|
||||||
|
|
||||||
READ_FILE_DESCRIPTION = `Read the contents of a text file relative to the current working directory.
|
|
||||||
|
|
||||||
Each line of the file is prefixed with its line number and a tabs (\t) to make
|
|
||||||
it make it easier to see which lines to change for other modifications.
|
|
||||||
|
|
||||||
Example result:
|
|
||||||
{
|
|
||||||
"message": "success", // if successful, or a different message indicating failure
|
|
||||||
"result": "1\tthe contents\n2\tof the file\n"
|
|
||||||
}`
|
|
||||||
|
|
||||||
WRITE_FILE_DESCRIPTION = `Write the provided contents to a file relative to the current working directory.
|
|
||||||
|
|
||||||
Note: only use this tool when you've been explicitly asked to create or write to a file.
|
|
||||||
|
|
||||||
When using this function, you do not need to share the content you intend to write with the user first.
|
|
||||||
|
|
||||||
Example result:
|
|
||||||
{
|
|
||||||
"message": "success", // if successful, or a different message indicating failure
|
|
||||||
}`
|
|
||||||
|
|
||||||
FILE_INSERT_LINES_DESCRIPTION = `Insert lines into a file, must specify path.
|
|
||||||
|
|
||||||
Make sure your inserts match the flow and indentation of surrounding content.`
|
|
||||||
|
|
||||||
FILE_REPLACE_LINES_DESCRIPTION = `Replace or remove a range of lines within a file, must specify path.
|
|
||||||
|
|
||||||
Useful for re-writing snippets/blocks of code or entire functions.
|
|
||||||
|
|
||||||
Be cautious with your edits. When replacing, ensure the replacement content matches the flow and indentation of surrounding content.`
|
|
||||||
)
|
|
||||||
|
|
||||||
var AvailableTools = map[string]AvailableTool{
|
|
||||||
"read_dir": {
|
|
||||||
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
|
|
||||||
Name: "read_dir",
|
|
||||||
Description: READ_DIR_DESCRIPTION,
|
|
||||||
Parameters: FunctionParameters{
|
|
||||||
Type: "object",
|
|
||||||
Properties: map[string]FunctionParameter{
|
|
||||||
"relative_dir": {
|
|
||||||
Type: "string",
|
|
||||||
Description: "If set, read the contents of a directory relative to the current one.",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
Impl: func(args map[string]interface{}) (string, error) {
|
|
||||||
var relativeDir string
|
|
||||||
tmp, ok := args["relative_dir"]
|
|
||||||
if ok {
|
|
||||||
relativeDir, ok = tmp.(string)
|
|
||||||
if !ok {
|
|
||||||
return "", fmt.Errorf("Invalid relative_dir in function arguments: %v", tmp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ReadDir(relativeDir), nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"read_file": {
|
|
||||||
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
|
|
||||||
Name: "read_file",
|
|
||||||
Description: READ_FILE_DESCRIPTION,
|
|
||||||
Parameters: FunctionParameters{
|
|
||||||
Type: "object",
|
|
||||||
Properties: map[string]FunctionParameter{
|
|
||||||
"path": {
|
|
||||||
Type: "string",
|
|
||||||
Description: "Path to a file within the current working directory to read.",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Required: []string{"path"},
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
Impl: func(args map[string]interface{}) (string, error) {
|
|
||||||
tmp, ok := args["path"]
|
|
||||||
if !ok {
|
|
||||||
return "", fmt.Errorf("Path parameter to read_file was not included.")
|
|
||||||
}
|
|
||||||
path, ok := tmp.(string)
|
|
||||||
if !ok {
|
|
||||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
|
||||||
}
|
|
||||||
return ReadFile(path), nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"write_file": {
|
|
||||||
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
|
|
||||||
Name: "write_file",
|
|
||||||
Description: WRITE_FILE_DESCRIPTION,
|
|
||||||
Parameters: FunctionParameters{
|
|
||||||
Type: "object",
|
|
||||||
Properties: map[string]FunctionParameter{
|
|
||||||
"path": {
|
|
||||||
Type: "string",
|
|
||||||
Description: "Path to a file within the current working directory to write to.",
|
|
||||||
},
|
|
||||||
"content": {
|
|
||||||
Type: "string",
|
|
||||||
Description: "The content to write to the file. Overwrites any existing content!",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Required: []string{"path", "content"},
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
Impl: func(args map[string]interface{}) (string, error) {
|
|
||||||
tmp, ok := args["path"]
|
|
||||||
if !ok {
|
|
||||||
return "", fmt.Errorf("Path parameter to write_file was not included.")
|
|
||||||
}
|
|
||||||
path, ok := tmp.(string)
|
|
||||||
if !ok {
|
|
||||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
|
||||||
}
|
|
||||||
tmp, ok = args["content"]
|
|
||||||
if !ok {
|
|
||||||
return "", fmt.Errorf("Content parameter to write_file was not included.")
|
|
||||||
}
|
|
||||||
content, ok := tmp.(string)
|
|
||||||
if !ok {
|
|
||||||
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
|
||||||
}
|
|
||||||
return WriteFile(path, content), nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"file_insert_lines": {
|
|
||||||
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
|
|
||||||
Name: "file_insert_lines",
|
|
||||||
Description: FILE_INSERT_LINES_DESCRIPTION,
|
|
||||||
Parameters: FunctionParameters{
|
|
||||||
Type: "object",
|
|
||||||
Properties: map[string]FunctionParameter{
|
|
||||||
"path": {
|
|
||||||
Type: "string",
|
|
||||||
Description: "Path of the file to be modified, relative to the current working directory.",
|
|
||||||
},
|
|
||||||
"position": {
|
|
||||||
Type: "integer",
|
|
||||||
Description: `Which line to insert content *before*.`,
|
|
||||||
},
|
|
||||||
"content": {
|
|
||||||
Type: "string",
|
|
||||||
Description: `The content to insert.`,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Required: []string{"path", "position", "content"},
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
Impl: func(args map[string]interface{}) (string, error) {
|
|
||||||
tmp, ok := args["path"]
|
|
||||||
if !ok {
|
|
||||||
return "", fmt.Errorf("path parameter to write_file was not included.")
|
|
||||||
}
|
|
||||||
path, ok := tmp.(string)
|
|
||||||
if !ok {
|
|
||||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
|
||||||
}
|
|
||||||
var position int
|
|
||||||
tmp, ok = args["position"]
|
|
||||||
if ok {
|
|
||||||
tmp, ok := tmp.(float64)
|
|
||||||
if !ok {
|
|
||||||
return "", fmt.Errorf("Invalid position in function arguments: %v", tmp)
|
|
||||||
}
|
|
||||||
position = int(tmp)
|
|
||||||
}
|
|
||||||
var content string
|
|
||||||
tmp, ok = args["content"]
|
|
||||||
if ok {
|
|
||||||
content, ok = tmp.(string)
|
|
||||||
if !ok {
|
|
||||||
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return FileInsertLines(path, position, content), nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"file_replace_lines": {
|
|
||||||
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
|
|
||||||
Name: "file_replace_lines",
|
|
||||||
Description: FILE_REPLACE_LINES_DESCRIPTION,
|
|
||||||
Parameters: FunctionParameters{
|
|
||||||
Type: "object",
|
|
||||||
Properties: map[string]FunctionParameter{
|
|
||||||
"path": {
|
|
||||||
Type: "string",
|
|
||||||
Description: "Path of the file to be modified, relative to the current working directory.",
|
|
||||||
},
|
|
||||||
"start_line": {
|
|
||||||
Type: "integer",
|
|
||||||
Description: `Line number which specifies the start of the replacement range (inclusive).`,
|
|
||||||
},
|
|
||||||
"end_line": {
|
|
||||||
Type: "integer",
|
|
||||||
Description: `Line number which specifies the end of the replacement range (inclusive). If unset, range extends to end of file.`,
|
|
||||||
},
|
|
||||||
"content": {
|
|
||||||
Type: "string",
|
|
||||||
Description: `Content to replace specified range. Omit to remove the specified range.`,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Required: []string{"path", "start_line"},
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
Impl: func(args map[string]interface{}) (string, error) {
|
|
||||||
tmp, ok := args["path"]
|
|
||||||
if !ok {
|
|
||||||
return "", fmt.Errorf("path parameter to write_file was not included.")
|
|
||||||
}
|
|
||||||
path, ok := tmp.(string)
|
|
||||||
if !ok {
|
|
||||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
|
||||||
}
|
|
||||||
var start_line int
|
|
||||||
tmp, ok = args["start_line"]
|
|
||||||
if ok {
|
|
||||||
tmp, ok := tmp.(float64)
|
|
||||||
if !ok {
|
|
||||||
return "", fmt.Errorf("Invalid start_line in function arguments: %v", tmp)
|
|
||||||
}
|
|
||||||
start_line = int(tmp)
|
|
||||||
}
|
|
||||||
var end_line int
|
|
||||||
tmp, ok = args["end_line"]
|
|
||||||
if ok {
|
|
||||||
tmp, ok := tmp.(float64)
|
|
||||||
if !ok {
|
|
||||||
return "", fmt.Errorf("Invalid end_line in function arguments: %v", tmp)
|
|
||||||
}
|
|
||||||
end_line = int(tmp)
|
|
||||||
}
|
|
||||||
var content string
|
|
||||||
tmp, ok = args["content"]
|
|
||||||
if ok {
|
|
||||||
content, ok = tmp.(string)
|
|
||||||
if !ok {
|
|
||||||
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return FileReplaceLines(path, start_line, end_line, content), nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
func resultToJson(result FunctionResult) string {
|
|
||||||
if result.Message == "" {
|
|
||||||
// When message not supplied, assume success
|
|
||||||
result.Message = "success"
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonBytes, err := json.Marshal(result)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("Could not marshal FunctionResult to JSON: %v\n", err)
|
|
||||||
}
|
|
||||||
return string(jsonBytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExecuteToolCalls handles the execution of all tool_calls provided, and
|
|
||||||
// returns their results formatted as []Message(s) with role: 'tool' and.
|
|
||||||
func ExecuteToolCalls(toolCalls []openai.ToolCall) ([]Message, error) {
|
|
||||||
var toolResults []Message
|
|
||||||
for _, toolCall := range toolCalls {
|
|
||||||
if toolCall.Type != "function" {
|
|
||||||
// unsupported tool type
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
tool, ok := AvailableTools[toolCall.Function.Name]
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("Requested tool '%s' does not exist. Hallucination?", toolCall.Function.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
var functionArgs map[string]interface{}
|
|
||||||
err := json.Unmarshal([]byte(toolCall.Function.Arguments), &functionArgs)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("Could not unmarshal tool arguments. Malformed JSON? Error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: ability to silence this
|
|
||||||
fmt.Fprintf(os.Stderr, "INFO: Executing tool '%s' with args %s\n", toolCall.Function.Name, toolCall.Function.Arguments)
|
|
||||||
|
|
||||||
// Execute the tool
|
|
||||||
toolResult, err := tool.Impl(functionArgs)
|
|
||||||
if err != nil {
|
|
||||||
// This can happen if the model missed or supplied invalid tool args
|
|
||||||
return nil, fmt.Errorf("Tool '%s' error: %v\n", toolCall.Function.Name, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
toolResults = append(toolResults, Message{
|
|
||||||
Role: "tool",
|
|
||||||
OriginalContent: toolResult,
|
|
||||||
ToolCallID: sql.NullString{String: toolCall.ID, Valid: true},
|
|
||||||
// name is not required since the introduction of ToolCallID
|
|
||||||
// hypothesis: by setting it, we inform the model of what a
|
|
||||||
// function's purpose was if future requests omit the function
|
|
||||||
// definition
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return toolResults, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// isPathContained attempts to verify whether `path` is the same as or
|
|
||||||
// contained within `directory`. It is overly cautious, returning false even if
|
|
||||||
// `path` IS contained within `directory`, but the two paths use different
|
|
||||||
// casing, and we happen to be on a case-insensitive filesystem.
|
|
||||||
// This is ultimately to attempt to stop an LLM from going outside of where I
|
|
||||||
// tell it to. Additional layers of security should be considered.. run in a
|
|
||||||
// VM/container.
|
|
||||||
func isPathContained(directory string, path string) (bool, error) {
|
|
||||||
// Clean and resolve symlinks for both paths
|
|
||||||
path, err := filepath.Abs(path)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// check if path exists
|
|
||||||
_, err = os.Stat(path)
|
|
||||||
if err != nil {
|
|
||||||
if !os.IsNotExist(err) {
|
|
||||||
return false, fmt.Errorf("Could not stat path: %v", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
path, err = filepath.EvalSymlinks(path)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
directory, err = filepath.Abs(directory)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
directory, err = filepath.EvalSymlinks(directory)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Case insensitive checks
|
|
||||||
if !strings.EqualFold(path, directory) &&
|
|
||||||
!strings.HasPrefix(strings.ToLower(path), strings.ToLower(directory)+string(os.PathSeparator)) {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isPathWithinCWD(path string) (bool, *FunctionResult) {
|
|
||||||
cwd, err := os.Getwd()
|
|
||||||
if err != nil {
|
|
||||||
return false, &FunctionResult{Message: "Failed to determine current working directory"}
|
|
||||||
}
|
|
||||||
if ok, err := isPathContained(cwd, path); !ok {
|
|
||||||
if err != nil {
|
|
||||||
return false, &FunctionResult{Message: fmt.Sprintf("Could not determine whether path '%s' is within the current working directory: %s", path, err.Error())}
|
|
||||||
}
|
|
||||||
return false, &FunctionResult{Message: fmt.Sprintf("Path '%s' is not within the current working directory", path)}
|
|
||||||
}
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func ReadDir(path string) string {
|
|
||||||
// TODO(?): implement whitelist - list of directories which model is allowed to work in
|
|
||||||
if path == "" {
|
|
||||||
path = "."
|
|
||||||
}
|
|
||||||
ok, res := isPathWithinCWD(path)
|
|
||||||
if !ok {
|
|
||||||
return resultToJson(*res)
|
|
||||||
}
|
|
||||||
|
|
||||||
files, err := os.ReadDir(path)
|
|
||||||
if err != nil {
|
|
||||||
return resultToJson(FunctionResult{
|
|
||||||
Message: err.Error(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
var dirContents []map[string]interface{}
|
|
||||||
for _, f := range files {
|
|
||||||
info, _ := f.Info()
|
|
||||||
|
|
||||||
name := f.Name()
|
|
||||||
if strings.HasPrefix(name, ".") {
|
|
||||||
// skip hidden files
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
entryType := "file"
|
|
||||||
size := info.Size()
|
|
||||||
|
|
||||||
if info.IsDir() {
|
|
||||||
name += "/"
|
|
||||||
entryType = "dir"
|
|
||||||
subdirfiles, _ := os.ReadDir(filepath.Join(".", path, info.Name()))
|
|
||||||
size = int64(len(subdirfiles))
|
|
||||||
}
|
|
||||||
|
|
||||||
dirContents = append(dirContents, map[string]interface{}{
|
|
||||||
"name": name,
|
|
||||||
"type": entryType,
|
|
||||||
"size": size,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return resultToJson(FunctionResult{Result: dirContents})
|
|
||||||
}
|
|
||||||
|
|
||||||
func ReadFile(path string) string {
|
|
||||||
ok, res := isPathWithinCWD(path)
|
|
||||||
if !ok {
|
|
||||||
return resultToJson(*res)
|
|
||||||
}
|
|
||||||
data, err := os.ReadFile(path)
|
|
||||||
if err != nil {
|
|
||||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())})
|
|
||||||
}
|
|
||||||
|
|
||||||
lines := strings.Split(string(data), "\n")
|
|
||||||
content := strings.Builder{}
|
|
||||||
for i, line := range lines {
|
|
||||||
content.WriteString(fmt.Sprintf("%d\t%s\n", i+1, line))
|
|
||||||
}
|
|
||||||
|
|
||||||
return resultToJson(FunctionResult{
|
|
||||||
Result: content.String(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func WriteFile(path string, content string) string {
|
|
||||||
ok, res := isPathWithinCWD(path)
|
|
||||||
if !ok {
|
|
||||||
return resultToJson(*res)
|
|
||||||
}
|
|
||||||
err := os.WriteFile(path, []byte(content), 0644)
|
|
||||||
if err != nil {
|
|
||||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())})
|
|
||||||
}
|
|
||||||
return resultToJson(FunctionResult{})
|
|
||||||
}
|
|
||||||
|
|
||||||
func FileInsertLines(path string, position int, content string) string {
|
|
||||||
ok, res := isPathWithinCWD(path)
|
|
||||||
if !ok {
|
|
||||||
return resultToJson(*res)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read the existing file's content
|
|
||||||
data, err := os.ReadFile(path)
|
|
||||||
if err != nil {
|
|
||||||
if !os.IsNotExist(err) {
|
|
||||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())})
|
|
||||||
}
|
|
||||||
_, err = os.Create(path)
|
|
||||||
if err != nil {
|
|
||||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())})
|
|
||||||
}
|
|
||||||
data = []byte{}
|
|
||||||
}
|
|
||||||
|
|
||||||
if position < 1 {
|
|
||||||
return resultToJson(FunctionResult{Message: "start_line cannot be less than 1"})
|
|
||||||
}
|
|
||||||
|
|
||||||
lines := strings.Split(string(data), "\n")
|
|
||||||
contentLines := strings.Split(strings.Trim(content, "\n"), "\n")
|
|
||||||
|
|
||||||
before := lines[:position-1]
|
|
||||||
after := lines[position-1:]
|
|
||||||
lines = append(before, append(contentLines, after...)...)
|
|
||||||
|
|
||||||
newContent := strings.Join(lines, "\n")
|
|
||||||
|
|
||||||
// Join the lines and write back to the file
|
|
||||||
err = os.WriteFile(path, []byte(newContent), 0644)
|
|
||||||
if err != nil {
|
|
||||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())})
|
|
||||||
}
|
|
||||||
|
|
||||||
return resultToJson(FunctionResult{Result: newContent})
|
|
||||||
}
|
|
||||||
|
|
||||||
func FileReplaceLines(path string, startLine int, endLine int, content string) string {
|
|
||||||
ok, res := isPathWithinCWD(path)
|
|
||||||
if !ok {
|
|
||||||
return resultToJson(*res)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read the existing file's content
|
|
||||||
data, err := os.ReadFile(path)
|
|
||||||
if err != nil {
|
|
||||||
if !os.IsNotExist(err) {
|
|
||||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())})
|
|
||||||
}
|
|
||||||
_, err = os.Create(path)
|
|
||||||
if err != nil {
|
|
||||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())})
|
|
||||||
}
|
|
||||||
data = []byte{}
|
|
||||||
}
|
|
||||||
|
|
||||||
if startLine < 1 {
|
|
||||||
return resultToJson(FunctionResult{Message: "start_line cannot be less than 1"})
|
|
||||||
}
|
|
||||||
|
|
||||||
lines := strings.Split(string(data), "\n")
|
|
||||||
contentLines := strings.Split(strings.Trim(content, "\n"), "\n")
|
|
||||||
|
|
||||||
if endLine == 0 || endLine > len(lines) {
|
|
||||||
endLine = len(lines)
|
|
||||||
}
|
|
||||||
|
|
||||||
before := lines[:startLine-1]
|
|
||||||
after := lines[endLine:]
|
|
||||||
|
|
||||||
lines = append(before, append(contentLines, after...)...)
|
|
||||||
newContent := strings.Join(lines, "\n")
|
|
||||||
|
|
||||||
// Join the lines and write back to the file
|
|
||||||
err = os.WriteFile(path, []byte(newContent), 0644)
|
|
||||||
if err != nil {
|
|
||||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())})
|
|
||||||
}
|
|
||||||
|
|
||||||
return resultToJson(FunctionResult{Result: newContent})
|
|
||||||
|
|
||||||
}
|
|
@ -1,187 +0,0 @@
|
|||||||
package cli
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
openai "github.com/sashabaranov/go-openai"
|
|
||||||
)
|
|
||||||
|
|
||||||
func CreateChatCompletionRequest(model string, messages []Message, maxTokens int) openai.ChatCompletionRequest {
|
|
||||||
chatCompletionMessages := []openai.ChatCompletionMessage{}
|
|
||||||
for _, m := range messages {
|
|
||||||
message := openai.ChatCompletionMessage{
|
|
||||||
Role: string(m.Role),
|
|
||||||
Content: m.OriginalContent,
|
|
||||||
}
|
|
||||||
if m.ToolCallID.Valid {
|
|
||||||
message.ToolCallID = m.ToolCallID.String
|
|
||||||
}
|
|
||||||
if m.ToolCalls.Valid {
|
|
||||||
// unmarshal directly into chatMessage.ToolCalls
|
|
||||||
err := json.Unmarshal([]byte(m.ToolCalls.String), &message.ToolCalls)
|
|
||||||
if err != nil {
|
|
||||||
// TODO: handle, this shouldn't really happen since
|
|
||||||
// we only save the successfully marshal'd data to database
|
|
||||||
fmt.Printf("Error unmarshalling the tool_calls JSON: %v\n", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
chatCompletionMessages = append(chatCompletionMessages, message)
|
|
||||||
}
|
|
||||||
|
|
||||||
request := openai.ChatCompletionRequest{
|
|
||||||
Model: model,
|
|
||||||
Messages: chatCompletionMessages,
|
|
||||||
MaxTokens: maxTokens,
|
|
||||||
N: 1, // limit responses to 1 "choice". we use choices[0] to reference it
|
|
||||||
}
|
|
||||||
|
|
||||||
var tools []openai.Tool
|
|
||||||
for _, t := range config.OpenAI.EnabledTools {
|
|
||||||
tool, ok := AvailableTools[t]
|
|
||||||
if ok {
|
|
||||||
tools = append(tools, tool.Tool)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(tools) > 0 {
|
|
||||||
request.Tools = tools
|
|
||||||
request.ToolChoice = "auto"
|
|
||||||
}
|
|
||||||
|
|
||||||
return request
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateChatCompletion submits a Chat Completion API request and returns the
|
|
||||||
// response. CreateChatCompletion will recursively call itself in the case of
|
|
||||||
// tool calls, until a response is received with the final user-facing output.
|
|
||||||
func CreateChatCompletion(model string, messages []Message, maxTokens int, replies *[]Message) (string, error) {
|
|
||||||
client := openai.NewClient(*config.OpenAI.APIKey)
|
|
||||||
req := CreateChatCompletionRequest(model, messages, maxTokens)
|
|
||||||
resp, err := client.CreateChatCompletion(context.Background(), req)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
choice := resp.Choices[0]
|
|
||||||
|
|
||||||
if len(choice.Message.ToolCalls) > 0 {
|
|
||||||
// Append the assistant's reply with its request for tool calls
|
|
||||||
toolCallJson, _ := json.Marshal(choice.Message.ToolCalls)
|
|
||||||
assistantReply := Message{
|
|
||||||
Role: MessageRoleAssistant,
|
|
||||||
ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
|
|
||||||
}
|
|
||||||
|
|
||||||
toolReplies, err := ExecuteToolCalls(choice.Message.ToolCalls)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
if replies != nil {
|
|
||||||
*replies = append(append(*replies, assistantReply), toolReplies...)
|
|
||||||
}
|
|
||||||
|
|
||||||
messages = append(append(messages, assistantReply), toolReplies...)
|
|
||||||
// Recurse into CreateChatCompletion with the tool call replies added
|
|
||||||
// to the original messages
|
|
||||||
return CreateChatCompletion(model, messages, maxTokens, replies)
|
|
||||||
}
|
|
||||||
|
|
||||||
if replies != nil {
|
|
||||||
*replies = append(*replies, Message{
|
|
||||||
Role: MessageRoleAssistant,
|
|
||||||
OriginalContent: choice.Message.Content,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return the user-facing message.
|
|
||||||
return choice.Message.Content, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateChatCompletionStream submits a streaming Chat Completion API request
|
|
||||||
// and both returns and streams the response to the provided output channel.
|
|
||||||
// May return a partial response if an error occurs mid-stream.
|
|
||||||
func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan<- string, replies *[]Message) (string, error) {
|
|
||||||
client := openai.NewClient(*config.OpenAI.APIKey)
|
|
||||||
req := CreateChatCompletionRequest(model, messages, maxTokens)
|
|
||||||
|
|
||||||
stream, err := client.CreateChatCompletionStream(context.Background(), req)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
defer stream.Close()
|
|
||||||
|
|
||||||
content := strings.Builder{}
|
|
||||||
toolCalls := []openai.ToolCall{}
|
|
||||||
|
|
||||||
// Iterate stream segments
|
|
||||||
for {
|
|
||||||
response, e := stream.Recv()
|
|
||||||
if errors.Is(e, io.EOF) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
if e != nil {
|
|
||||||
err = e
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
delta := response.Choices[0].Delta
|
|
||||||
if len(delta.ToolCalls) > 0 {
|
|
||||||
// Construct streamed tool_call arguments
|
|
||||||
for _, tc := range delta.ToolCalls {
|
|
||||||
if tc.Index == nil {
|
|
||||||
return "", fmt.Errorf("Unexpected nil index for streamed tool call.")
|
|
||||||
}
|
|
||||||
if len(toolCalls) <= *tc.Index {
|
|
||||||
toolCalls = append(toolCalls, tc)
|
|
||||||
} else {
|
|
||||||
toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
output <- delta.Content
|
|
||||||
content.WriteString(delta.Content)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(toolCalls) > 0 {
|
|
||||||
// Append the assistant's reply with its request for tool calls
|
|
||||||
toolCallJson, _ := json.Marshal(toolCalls)
|
|
||||||
|
|
||||||
assistantReply := Message{
|
|
||||||
Role: MessageRoleAssistant,
|
|
||||||
OriginalContent: content.String(),
|
|
||||||
ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
|
|
||||||
}
|
|
||||||
|
|
||||||
toolReplies, err := ExecuteToolCalls(toolCalls)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
if replies != nil {
|
|
||||||
*replies = append(append(*replies, assistantReply), toolReplies...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Recurse into CreateChatCompletionStream with the tool call replies
|
|
||||||
// added to the original messages
|
|
||||||
messages = append(append(messages, assistantReply), toolReplies...)
|
|
||||||
return CreateChatCompletionStream(model, messages, maxTokens, output, replies)
|
|
||||||
}
|
|
||||||
|
|
||||||
if replies != nil {
|
|
||||||
*replies = append(*replies, Message{
|
|
||||||
Role: MessageRoleAssistant,
|
|
||||||
OriginalContent: content.String(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return content.String(), err
|
|
||||||
}
|
|
141
pkg/cli/store.go
141
pkg/cli/store.go
@ -1,141 +0,0 @@
|
|||||||
package cli
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
sqids "github.com/sqids/sqids-go"
|
|
||||||
"gorm.io/driver/sqlite"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Store struct {
|
|
||||||
db *gorm.DB
|
|
||||||
sqids *sqids.Sqids
|
|
||||||
}
|
|
||||||
|
|
||||||
type Message struct {
|
|
||||||
ID uint `gorm:"primaryKey"`
|
|
||||||
ConversationID uint `gorm:"foreignKey:ConversationID"`
|
|
||||||
Conversation Conversation
|
|
||||||
OriginalContent string
|
|
||||||
Role MessageRole // one of: 'system', 'user', 'assistant', 'tool'
|
|
||||||
CreatedAt time.Time
|
|
||||||
ToolCallID sql.NullString
|
|
||||||
ToolCalls sql.NullString // a json-encoded array of tool calls from the model
|
|
||||||
}
|
|
||||||
|
|
||||||
type Conversation struct {
|
|
||||||
ID uint `gorm:"primaryKey"`
|
|
||||||
ShortName sql.NullString
|
|
||||||
Title string
|
|
||||||
}
|
|
||||||
|
|
||||||
func dataDir() string {
|
|
||||||
var dataDir string
|
|
||||||
|
|
||||||
xdgDataHome := os.Getenv("XDG_DATA_HOME")
|
|
||||||
if xdgDataHome != "" {
|
|
||||||
dataDir = filepath.Join(xdgDataHome, "lmcli")
|
|
||||||
} else {
|
|
||||||
userHomeDir, _ := os.UserHomeDir()
|
|
||||||
dataDir = filepath.Join(userHomeDir, ".local/share/lmcli")
|
|
||||||
}
|
|
||||||
|
|
||||||
os.MkdirAll(dataDir, 0755)
|
|
||||||
return dataDir
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewStore() (*Store, error) {
|
|
||||||
databaseFile := filepath.Join(dataDir(), "conversations.db")
|
|
||||||
db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{})
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("Error establishing connection to store: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
models := []any{
|
|
||||||
&Conversation{},
|
|
||||||
&Message{},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, x := range models {
|
|
||||||
err := db.AutoMigrate(x)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("Could not perform database migrations: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
_sqids, _ := sqids.New(sqids.Options{MinLength: 4})
|
|
||||||
return &Store{db, _sqids}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) SaveConversation(conversation *Conversation) error {
|
|
||||||
err := s.db.Save(&conversation).Error
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !conversation.ShortName.Valid {
|
|
||||||
shortName, _ := s.sqids.Encode([]uint64{uint64(conversation.ID)})
|
|
||||||
conversation.ShortName = sql.NullString{String: shortName, Valid: true}
|
|
||||||
err = s.db.Save(&conversation).Error
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) DeleteConversation(conversation *Conversation) error {
|
|
||||||
s.db.Where("conversation_id = ?", conversation.ID).Delete(&Message{})
|
|
||||||
return s.db.Delete(&conversation).Error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) SaveMessage(message *Message) error {
|
|
||||||
return s.db.Create(message).Error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) DeleteMessage(message *Message) error {
|
|
||||||
return s.db.Delete(&message).Error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) Conversations() ([]Conversation, error) {
|
|
||||||
var conversations []Conversation
|
|
||||||
err := s.db.Find(&conversations).Error
|
|
||||||
return conversations, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) ConversationShortNameCompletions(shortName string) []string {
|
|
||||||
var completions []string
|
|
||||||
conversations, _ := s.Conversations() // ignore error for completions
|
|
||||||
for _, conversation := range conversations {
|
|
||||||
if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) {
|
|
||||||
completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return completions
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) ConversationByShortName(shortName string) (*Conversation, error) {
|
|
||||||
if shortName == "" {
|
|
||||||
return nil, errors.New("shortName is empty")
|
|
||||||
}
|
|
||||||
var conversation Conversation
|
|
||||||
err := s.db.Where("short_name = ?", shortName).Find(&conversation).Error
|
|
||||||
return &conversation, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) Messages(conversation *Conversation) ([]Message, error) {
|
|
||||||
var messages []Message
|
|
||||||
err := s.db.Where("conversation_id = ?", conversation.ID).Find(&messages).Error
|
|
||||||
return messages, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) LastMessage(conversation *Conversation) (*Message, error) {
|
|
||||||
var message Message
|
|
||||||
err := s.db.Where("conversation_id = ?", conversation.ID).Last(&message).Error
|
|
||||||
return &message, err
|
|
||||||
}
|
|
113
pkg/cli/tty.go
113
pkg/cli/tty.go
@ -1,113 +0,0 @@
|
|||||||
package cli
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/alecthomas/chroma/v2/quick"
|
|
||||||
"github.com/gookit/color"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ShowWaitAnimation "draws" an animated ellipses to stdout until something is
|
|
||||||
// received on the signal channel. An empty string sent to the channel to
|
|
||||||
// noftify the caller that the animation has completed (carriage returned).
|
|
||||||
func ShowWaitAnimation(signal chan any) {
|
|
||||||
animationStep := 0
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case _ = <-signal:
|
|
||||||
fmt.Print("\r")
|
|
||||||
signal <- ""
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
modSix := animationStep % 6
|
|
||||||
if modSix == 3 || modSix == 0 {
|
|
||||||
fmt.Print("\r")
|
|
||||||
}
|
|
||||||
if modSix < 3 {
|
|
||||||
fmt.Print(".")
|
|
||||||
} else {
|
|
||||||
fmt.Print(" ")
|
|
||||||
}
|
|
||||||
animationStep++
|
|
||||||
time.Sleep(250 * time.Millisecond)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ShowDelayedContent displays a waiting animation to stdout while waiting
|
|
||||||
// for content to be received on the provided channel. As soon as any (possibly
|
|
||||||
// chunked) content is received on the channel, the waiting animation is
|
|
||||||
// replaced by the content.
|
|
||||||
// Blocks until the channel is closed.
|
|
||||||
func ShowDelayedContent(content <-chan string) {
|
|
||||||
waitSignal := make(chan any)
|
|
||||||
go ShowWaitAnimation(waitSignal)
|
|
||||||
|
|
||||||
firstChunk := true
|
|
||||||
for chunk := range content {
|
|
||||||
if firstChunk {
|
|
||||||
// notify wait animation that we've received data
|
|
||||||
waitSignal <- ""
|
|
||||||
// wait for signal that wait animation has completed
|
|
||||||
<-waitSignal
|
|
||||||
firstChunk = false
|
|
||||||
}
|
|
||||||
fmt.Print(chunk)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// RenderConversation renders the given messages to TTY, with optional space
|
|
||||||
// for a subsequent message. spaceForResponse controls how many '\n' characters
|
|
||||||
// are printed immediately after the final message (1 if false, 2 if true)
|
|
||||||
func RenderConversation(messages []Message, spaceForResponse bool) {
|
|
||||||
l := len(messages)
|
|
||||||
for i, message := range messages {
|
|
||||||
message.RenderTTY()
|
|
||||||
if i < l-1 || spaceForResponse {
|
|
||||||
// print an additional space before the next message
|
|
||||||
fmt.Println()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// HighlightMarkdown applies syntax highlighting to the provided markdown text
|
|
||||||
// and writes it to stdout.
|
|
||||||
func HighlightMarkdown(markdownText string) error {
|
|
||||||
return quick.Highlight(os.Stdout, markdownText, "md", *config.Chroma.Formatter, *config.Chroma.Style)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Message) RenderTTY() {
|
|
||||||
var messageAge string
|
|
||||||
if m.CreatedAt.IsZero() {
|
|
||||||
messageAge = "now"
|
|
||||||
} else {
|
|
||||||
now := time.Now()
|
|
||||||
messageAge = humanTimeElapsedSince(now.Sub(m.CreatedAt))
|
|
||||||
}
|
|
||||||
|
|
||||||
var roleStyle color.Style
|
|
||||||
switch m.Role {
|
|
||||||
case MessageRoleSystem:
|
|
||||||
roleStyle = color.Style{color.HiRed}
|
|
||||||
case MessageRoleUser:
|
|
||||||
roleStyle = color.Style{color.HiGreen}
|
|
||||||
case MessageRoleAssistant:
|
|
||||||
roleStyle = color.Style{color.HiBlue}
|
|
||||||
default:
|
|
||||||
roleStyle = color.Style{color.FgWhite}
|
|
||||||
}
|
|
||||||
roleStyle.Add(color.Bold)
|
|
||||||
|
|
||||||
headerColor := color.FgYellow
|
|
||||||
separator := headerColor.Sprint("===")
|
|
||||||
timestamp := headerColor.Sprint(messageAge)
|
|
||||||
role := roleStyle.Sprint(m.FriendlyRole())
|
|
||||||
|
|
||||||
fmt.Printf("%s %s - %s %s\n\n", separator, role, timestamp, separator)
|
|
||||||
if m.OriginalContent != "" {
|
|
||||||
HighlightMarkdown(m.OriginalContent)
|
|
||||||
fmt.Println()
|
|
||||||
}
|
|
||||||
}
|
|
72
pkg/cmd/clone.go
Normal file
72
pkg/cmd/clone.go
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CloneCmd(ctx *lmcli.Context) *cobra.Command {
|
||||||
|
cmd := &cobra.Command{
|
||||||
|
Use: "clone <conversation>",
|
||||||
|
Short: "Clone conversations",
|
||||||
|
Long: `Clones the provided conversation.`,
|
||||||
|
Args: func(cmd *cobra.Command, args []string) error {
|
||||||
|
argCount := 1
|
||||||
|
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
shortName := args[0]
|
||||||
|
toClone, err := cmdutil.LookupConversationE(ctx, shortName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
messagesToCopy, err := ctx.Store.Messages(toClone)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Could not retrieve messages for conversation: %s", toClone.ShortName.String)
|
||||||
|
}
|
||||||
|
|
||||||
|
clone := &model.Conversation{
|
||||||
|
Title: toClone.Title + " - Clone",
|
||||||
|
}
|
||||||
|
if err := ctx.Store.SaveConversation(clone); err != nil {
|
||||||
|
return fmt.Errorf("Cloud not create clone: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var errors []error
|
||||||
|
messageCnt := 0
|
||||||
|
for _, message := range messagesToCopy {
|
||||||
|
newMessage := message
|
||||||
|
newMessage.ConversationID = clone.ID
|
||||||
|
newMessage.ID = 0
|
||||||
|
if err := ctx.Store.SaveMessage(&newMessage); err != nil {
|
||||||
|
errors = append(errors, err)
|
||||||
|
} else {
|
||||||
|
messageCnt++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errors) > 0 {
|
||||||
|
return fmt.Errorf("Messages failed to be cloned: %v", errors)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Cloned %d messages to: %s\n", messageCnt, clone.Title)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||||
|
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||||
|
if len(args) != 0 {
|
||||||
|
return nil, compMode
|
||||||
|
}
|
||||||
|
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return cmd
|
||||||
|
}
|
93
pkg/cmd/cmd.go
Normal file
93
pkg/cmd/cmd.go
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
systemPromptFile string
|
||||||
|
)
|
||||||
|
|
||||||
|
func RootCmd(ctx *lmcli.Context) *cobra.Command {
|
||||||
|
var root = &cobra.Command{
|
||||||
|
Use: "lmcli <command> [flags]",
|
||||||
|
Long: `lmcli - Large Language Model CLI`,
|
||||||
|
SilenceErrors: true,
|
||||||
|
SilenceUsage: true,
|
||||||
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
|
cmd.Usage()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
continueCmd := ContinueCmd(ctx)
|
||||||
|
cloneCmd := CloneCmd(ctx)
|
||||||
|
editCmd := EditCmd(ctx)
|
||||||
|
listCmd := ListCmd(ctx)
|
||||||
|
newCmd := NewCmd(ctx)
|
||||||
|
promptCmd := PromptCmd(ctx)
|
||||||
|
renameCmd := RenameCmd(ctx)
|
||||||
|
replyCmd := ReplyCmd(ctx)
|
||||||
|
retryCmd := RetryCmd(ctx)
|
||||||
|
rmCmd := RemoveCmd(ctx)
|
||||||
|
viewCmd := ViewCmd(ctx)
|
||||||
|
|
||||||
|
inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd, retryCmd, continueCmd, editCmd}
|
||||||
|
for _, cmd := range inputCmds {
|
||||||
|
cmd.Flags().StringVar(ctx.Config.Defaults.Model, "model", *ctx.Config.Defaults.Model, "Which model to use")
|
||||||
|
cmd.Flags().IntVar(ctx.Config.Defaults.MaxTokens, "length", *ctx.Config.Defaults.MaxTokens, "Maximum response tokens")
|
||||||
|
cmd.Flags().StringVar(ctx.Config.Defaults.SystemPrompt, "system-prompt", *ctx.Config.Defaults.SystemPrompt, "System prompt")
|
||||||
|
cmd.Flags().StringVar(&systemPromptFile, "system-prompt-file", "", "A path to a file containing the system prompt")
|
||||||
|
cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file")
|
||||||
|
}
|
||||||
|
|
||||||
|
renameCmd.Flags().Bool("generate", false, "Generate a conversation title")
|
||||||
|
|
||||||
|
root.AddCommand(
|
||||||
|
cloneCmd,
|
||||||
|
continueCmd,
|
||||||
|
editCmd,
|
||||||
|
listCmd,
|
||||||
|
newCmd,
|
||||||
|
promptCmd,
|
||||||
|
renameCmd,
|
||||||
|
replyCmd,
|
||||||
|
retryCmd,
|
||||||
|
rmCmd,
|
||||||
|
viewCmd,
|
||||||
|
)
|
||||||
|
|
||||||
|
return root
|
||||||
|
}
|
||||||
|
|
||||||
|
func getSystemPrompt(ctx *lmcli.Context) string {
|
||||||
|
if systemPromptFile != "" {
|
||||||
|
content, err := util.ReadFileContents(systemPromptFile)
|
||||||
|
if err != nil {
|
||||||
|
lmcli.Fatal("Could not read file contents at %s: %v\n", systemPromptFile, err)
|
||||||
|
}
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
return *ctx.Config.Defaults.SystemPrompt
|
||||||
|
}
|
||||||
|
|
||||||
|
// inputFromArgsOrEditor returns either the provided input from the args slice
|
||||||
|
// (joined with spaces), or if len(args) is 0, opens an editor and returns
|
||||||
|
// whatever input was provided there. placeholder is a string which populates
|
||||||
|
// the editor and gets stripped from the final output.
|
||||||
|
func inputFromArgsOrEditor(args []string, placeholder string, existingMessage string) (message string) {
|
||||||
|
var err error
|
||||||
|
if len(args) == 0 {
|
||||||
|
message, err = util.InputFromEditor(placeholder, "message.*.md", existingMessage)
|
||||||
|
if err != nil {
|
||||||
|
lmcli.Fatal("Failed to get input: %v\n", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
message = strings.Join(args, " ")
|
||||||
|
}
|
||||||
|
message = strings.Trim(message, " \t\n")
|
||||||
|
return
|
||||||
|
}
|
72
pkg/cmd/continue.go
Normal file
72
pkg/cmd/continue.go
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
|
||||||
|
cmd := &cobra.Command{
|
||||||
|
Use: "continue <conversation>",
|
||||||
|
Short: "Continue a conversation from the last message",
|
||||||
|
Long: `Re-prompt the conversation with all existing prompts. Useful if a reply was cut short.`,
|
||||||
|
Args: func(cmd *cobra.Command, args []string) error {
|
||||||
|
argCount := 1
|
||||||
|
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
shortName := args[0]
|
||||||
|
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||||
|
|
||||||
|
messages, err := ctx.Store.Messages(conversation)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("could not retrieve conversation messages: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(messages) < 2 {
|
||||||
|
return fmt.Errorf("conversation expected to have at least 2 messages")
|
||||||
|
}
|
||||||
|
|
||||||
|
lastMessage := &messages[len(messages)-1]
|
||||||
|
if lastMessage.Role != model.MessageRoleAssistant {
|
||||||
|
return fmt.Errorf("the last message in the conversation is not an assistant message")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Output the contents of the last message so far
|
||||||
|
fmt.Print(lastMessage.Content)
|
||||||
|
|
||||||
|
// Submit the LLM request, allowing it to continue the last message
|
||||||
|
continuedOutput, err := cmdutil.FetchAndShowCompletion(ctx, messages)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error fetching LLM response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append the new response to the original message
|
||||||
|
lastMessage.Content += strings.TrimRight(continuedOutput[0].Content, "\n\t ")
|
||||||
|
|
||||||
|
// Update the original message
|
||||||
|
err = ctx.Store.UpdateMessage(lastMessage)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("could not update the last message: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||||
|
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||||
|
if len(args) != 0 {
|
||||||
|
return nil, compMode
|
||||||
|
}
|
||||||
|
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return cmd
|
||||||
|
}
|
100
pkg/cmd/edit.go
Normal file
100
pkg/cmd/edit.go
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
func EditCmd(ctx *lmcli.Context) *cobra.Command {
|
||||||
|
cmd := &cobra.Command{
|
||||||
|
Use: "edit <conversation>",
|
||||||
|
Short: "Edit the last user reply in a conversation",
|
||||||
|
Args: func(cmd *cobra.Command, args []string) error {
|
||||||
|
argCount := 1
|
||||||
|
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
shortName := args[0]
|
||||||
|
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||||
|
|
||||||
|
messages, err := ctx.Store.Messages(conversation)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
|
||||||
|
}
|
||||||
|
|
||||||
|
offset, _ := cmd.Flags().GetInt("offset")
|
||||||
|
if offset < 0 {
|
||||||
|
offset = -offset
|
||||||
|
}
|
||||||
|
|
||||||
|
if offset > len(messages)-1 {
|
||||||
|
return fmt.Errorf("Offset %d is before the start of the conversation.", offset)
|
||||||
|
}
|
||||||
|
|
||||||
|
desiredIdx := len(messages) - 1 - offset
|
||||||
|
|
||||||
|
// walk backwards through the conversation deleting messages until and
|
||||||
|
// including the last user message
|
||||||
|
toRemove := []model.Message{}
|
||||||
|
var toEdit *model.Message
|
||||||
|
for i := len(messages) - 1; i >= 0; i-- {
|
||||||
|
if i == desiredIdx {
|
||||||
|
toEdit = &messages[i]
|
||||||
|
}
|
||||||
|
toRemove = append(toRemove, messages[i])
|
||||||
|
messages = messages[:i]
|
||||||
|
if toEdit != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
newContents := inputFromArgsOrEditor(args[1:], "# Save when finished editing\n", toEdit.Content)
|
||||||
|
switch newContents {
|
||||||
|
case toEdit.Content:
|
||||||
|
return fmt.Errorf("No edits were made.")
|
||||||
|
case "":
|
||||||
|
return fmt.Errorf("No message was provided.")
|
||||||
|
}
|
||||||
|
|
||||||
|
role, _ := cmd.Flags().GetString("role")
|
||||||
|
if role == "" {
|
||||||
|
role = string(toEdit.Role)
|
||||||
|
} else if role != string(model.MessageRoleUser) && role != string(model.MessageRoleAssistant) {
|
||||||
|
return fmt.Errorf("Invalid role specified. Please use 'user' or 'assistant'.")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, message := range toRemove {
|
||||||
|
err = ctx.Store.DeleteMessage(&message)
|
||||||
|
if err != nil {
|
||||||
|
lmcli.Warn("Could not delete message: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cmdutil.HandleConversationReply(ctx, conversation, true, model.Message{
|
||||||
|
ConversationID: conversation.ID,
|
||||||
|
Role: model.MessageRole(role),
|
||||||
|
Content: newContents,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||||
|
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||||
|
if len(args) != 0 {
|
||||||
|
return nil, compMode
|
||||||
|
}
|
||||||
|
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Flags().Int("offset", 1, "Offset from the last message to edit")
|
||||||
|
cmd.Flags().StringP("role", "r", "", "Role of the edited message (user or assistant)")
|
||||||
|
|
||||||
|
return cmd
|
||||||
|
}
|
122
pkg/cmd/list.go
Normal file
122
pkg/cmd/list.go
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
LS_COUNT int = 5
|
||||||
|
)
|
||||||
|
|
||||||
|
func ListCmd(ctx *lmcli.Context) *cobra.Command {
|
||||||
|
cmd := &cobra.Command{
|
||||||
|
Use: "list",
|
||||||
|
Aliases: []string{"ls"},
|
||||||
|
Short: "List conversations",
|
||||||
|
Long: `List conversations in order of recent activity`,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
conversations, err := ctx.Store.Conversations()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Could not fetch conversations: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Category struct {
|
||||||
|
name string
|
||||||
|
cutoff time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
type ConversationLine struct {
|
||||||
|
timeSinceReply time.Duration
|
||||||
|
formatted string
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
midnight := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
||||||
|
monthStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())
|
||||||
|
dayOfWeek := int(now.Weekday())
|
||||||
|
categories := []Category{
|
||||||
|
{"today", now.Sub(midnight)},
|
||||||
|
{"yesterday", now.Sub(midnight.AddDate(0, 0, -1))},
|
||||||
|
{"this week", now.Sub(midnight.AddDate(0, 0, -dayOfWeek))},
|
||||||
|
{"last week", now.Sub(midnight.AddDate(0, 0, -(dayOfWeek + 7)))},
|
||||||
|
{"this month", now.Sub(monthStart)},
|
||||||
|
{"last month", now.Sub(monthStart.AddDate(0, -1, 0))},
|
||||||
|
{"2 months ago", now.Sub(monthStart.AddDate(0, -2, 0))},
|
||||||
|
{"3 months ago", now.Sub(monthStart.AddDate(0, -3, 0))},
|
||||||
|
{"4 months ago", now.Sub(monthStart.AddDate(0, -4, 0))},
|
||||||
|
{"5 months ago", now.Sub(monthStart.AddDate(0, -5, 0))},
|
||||||
|
{"older", now.Sub(time.Time{})},
|
||||||
|
}
|
||||||
|
categorized := map[string][]ConversationLine{}
|
||||||
|
|
||||||
|
all, _ := cmd.Flags().GetBool("all")
|
||||||
|
|
||||||
|
for _, conversation := range conversations {
|
||||||
|
lastMessage, err := ctx.Store.LastMessage(&conversation)
|
||||||
|
if lastMessage == nil || err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
messageAge := now.Sub(lastMessage.CreatedAt)
|
||||||
|
|
||||||
|
var category string
|
||||||
|
for _, c := range categories {
|
||||||
|
if messageAge < c.cutoff {
|
||||||
|
category = c.name
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
formatted := fmt.Sprintf(
|
||||||
|
"%s - %s - %s",
|
||||||
|
conversation.ShortName.String,
|
||||||
|
util.HumanTimeElapsedSince(messageAge),
|
||||||
|
conversation.Title,
|
||||||
|
)
|
||||||
|
|
||||||
|
categorized[category] = append(
|
||||||
|
categorized[category],
|
||||||
|
ConversationLine{messageAge, formatted},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
count, _ := cmd.Flags().GetInt("count")
|
||||||
|
var conversationsPrinted int
|
||||||
|
outer:
|
||||||
|
for _, category := range categories {
|
||||||
|
conversationLines, ok := categorized[category.name]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
slices.SortFunc(conversationLines, func(a, b ConversationLine) int {
|
||||||
|
return int(a.timeSinceReply - b.timeSinceReply)
|
||||||
|
})
|
||||||
|
|
||||||
|
fmt.Printf("%s:\n", category.name)
|
||||||
|
for _, conv := range conversationLines {
|
||||||
|
if conversationsPrinted >= count && !all {
|
||||||
|
fmt.Printf("%d remaining message(s), use --all to view.\n", len(conversations)-conversationsPrinted)
|
||||||
|
break outer
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf(" %s\n", conv.formatted)
|
||||||
|
conversationsPrinted++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Flags().Bool("all", false, "Show all conversations")
|
||||||
|
cmd.Flags().Int("count", LS_COUNT, "How many conversations to show")
|
||||||
|
|
||||||
|
return cmd
|
||||||
|
}
|
60
pkg/cmd/new.go
Normal file
60
pkg/cmd/new.go
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewCmd(ctx *lmcli.Context) *cobra.Command {
|
||||||
|
cmd := &cobra.Command{
|
||||||
|
Use: "new [message]",
|
||||||
|
Short: "Start a new conversation",
|
||||||
|
Long: `Start a new conversation with the Large Language Model.`,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
messageContents := inputFromArgsOrEditor(args, "# What would you like to say?\n", "")
|
||||||
|
if messageContents == "" {
|
||||||
|
return fmt.Errorf("No message was provided.")
|
||||||
|
}
|
||||||
|
|
||||||
|
conversation := &model.Conversation{}
|
||||||
|
err := ctx.Store.SaveConversation(conversation)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Could not save new conversation: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
messages := []model.Message{
|
||||||
|
{
|
||||||
|
ConversationID: conversation.ID,
|
||||||
|
Role: model.MessageRoleSystem,
|
||||||
|
Content: getSystemPrompt(ctx),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ConversationID: conversation.ID,
|
||||||
|
Role: model.MessageRoleUser,
|
||||||
|
Content: messageContents,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cmdutil.HandleConversationReply(ctx, conversation, true, messages...)
|
||||||
|
|
||||||
|
title, err := cmdutil.GenerateTitle(ctx, conversation)
|
||||||
|
if err != nil {
|
||||||
|
lmcli.Warn("Could not generate title for conversation: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conversation.Title = title
|
||||||
|
|
||||||
|
err = ctx.Store.SaveConversation(conversation)
|
||||||
|
if err != nil {
|
||||||
|
lmcli.Warn("Could not save conversation after generating title: %v\n", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return cmd
|
||||||
|
}
|
42
pkg/cmd/prompt.go
Normal file
42
pkg/cmd/prompt.go
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
func PromptCmd(ctx *lmcli.Context) *cobra.Command {
|
||||||
|
cmd := &cobra.Command{
|
||||||
|
Use: "prompt [message]",
|
||||||
|
Short: "Do a one-shot prompt",
|
||||||
|
Long: `Prompt the Large Language Model and get a response.`,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
message := inputFromArgsOrEditor(args, "# What would you like to say?\n", "")
|
||||||
|
if message == "" {
|
||||||
|
return fmt.Errorf("No message was provided.")
|
||||||
|
}
|
||||||
|
|
||||||
|
messages := []model.Message{
|
||||||
|
{
|
||||||
|
Role: model.MessageRoleSystem,
|
||||||
|
Content: getSystemPrompt(ctx),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: model.MessageRoleUser,
|
||||||
|
Content: message,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := cmdutil.FetchAndShowCompletion(ctx, messages)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Error fetching LLM response: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return cmd
|
||||||
|
}
|
60
pkg/cmd/remove.go
Normal file
60
pkg/cmd/remove.go
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
func RemoveCmd(ctx *lmcli.Context) *cobra.Command {
|
||||||
|
cmd := &cobra.Command{
|
||||||
|
Use: "rm <conversation>...",
|
||||||
|
Short: "Remove conversations",
|
||||||
|
Long: `Remove conversations by their short names.`,
|
||||||
|
Args: func(cmd *cobra.Command, args []string) error {
|
||||||
|
argCount := 1
|
||||||
|
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
var toRemove []*model.Conversation
|
||||||
|
for _, shortName := range args {
|
||||||
|
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||||
|
toRemove = append(toRemove, conversation)
|
||||||
|
}
|
||||||
|
var errors []error
|
||||||
|
for _, c := range toRemove {
|
||||||
|
err := ctx.Store.DeleteConversation(c)
|
||||||
|
if err != nil {
|
||||||
|
errors = append(errors, fmt.Errorf("Could not remove conversation %s: %v", c.ShortName.String, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(errors) > 0 {
|
||||||
|
return fmt.Errorf("Could not remove some conversations: %v", errors)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||||
|
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||||
|
var completions []string
|
||||||
|
outer:
|
||||||
|
for _, completion := range ctx.Store.ConversationShortNameCompletions(toComplete) {
|
||||||
|
parts := strings.Split(completion, "\t")
|
||||||
|
for _, arg := range args {
|
||||||
|
if parts[0] == arg {
|
||||||
|
continue outer
|
||||||
|
}
|
||||||
|
}
|
||||||
|
completions = append(completions, completion)
|
||||||
|
}
|
||||||
|
return completions, compMode
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return cmd
|
||||||
|
}
|
60
pkg/cmd/rename.go
Normal file
60
pkg/cmd/rename.go
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func RenameCmd(ctx *lmcli.Context) *cobra.Command {
|
||||||
|
cmd := &cobra.Command{
|
||||||
|
Use: "rename <conversation> [title]",
|
||||||
|
Short: "Rename a conversation",
|
||||||
|
Long: `Renames a conversation, either with the provided title or by generating a new name.`,
|
||||||
|
Args: func(cmd *cobra.Command, args []string) error {
|
||||||
|
argCount := 1
|
||||||
|
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
shortName := args[0]
|
||||||
|
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||||
|
var err error
|
||||||
|
|
||||||
|
generate, _ := cmd.Flags().GetBool("generate")
|
||||||
|
var title string
|
||||||
|
if generate {
|
||||||
|
title, err = cmdutil.GenerateTitle(ctx, conversation)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Could not generate conversation title: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if len(args) < 2 {
|
||||||
|
return fmt.Errorf("Conversation title not provided.")
|
||||||
|
}
|
||||||
|
title = strings.Join(args[1:], " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
conversation.Title = title
|
||||||
|
err = ctx.Store.SaveConversation(conversation)
|
||||||
|
if err != nil {
|
||||||
|
lmcli.Warn("Could not save conversation with new title: %v\n", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||||
|
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||||
|
if len(args) != 0 {
|
||||||
|
return nil, compMode
|
||||||
|
}
|
||||||
|
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return cmd
|
||||||
|
}
|
49
pkg/cmd/reply.go
Normal file
49
pkg/cmd/reply.go
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ReplyCmd(ctx *lmcli.Context) *cobra.Command {
|
||||||
|
cmd := &cobra.Command{
|
||||||
|
Use: "reply <conversation> [message]",
|
||||||
|
Short: "Reply to a conversation",
|
||||||
|
Long: `Sends a reply to conversation and writes the response to stdout.`,
|
||||||
|
Args: func(cmd *cobra.Command, args []string) error {
|
||||||
|
argCount := 1
|
||||||
|
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
shortName := args[0]
|
||||||
|
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||||
|
|
||||||
|
reply := inputFromArgsOrEditor(args[1:], "# How would you like to reply?\n", "")
|
||||||
|
if reply == "" {
|
||||||
|
return fmt.Errorf("No reply was provided.")
|
||||||
|
}
|
||||||
|
|
||||||
|
cmdutil.HandleConversationReply(ctx, conversation, true, model.Message{
|
||||||
|
ConversationID: conversation.ID,
|
||||||
|
Role: model.MessageRoleUser,
|
||||||
|
Content: reply,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||||
|
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||||
|
if len(args) != 0 {
|
||||||
|
return nil, compMode
|
||||||
|
}
|
||||||
|
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return cmd
|
||||||
|
}
|
58
pkg/cmd/retry.go
Normal file
58
pkg/cmd/retry.go
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
func RetryCmd(ctx *lmcli.Context) *cobra.Command {
|
||||||
|
cmd := &cobra.Command{
|
||||||
|
Use: "retry <conversation>",
|
||||||
|
Short: "Retry the last user reply in a conversation",
|
||||||
|
Long: `Re-prompt the conversation up to the last user response. Can be used to regenerate the last assistant reply, or simply generate one if an error occurred.`,
|
||||||
|
Args: func(cmd *cobra.Command, args []string) error {
|
||||||
|
argCount := 1
|
||||||
|
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
shortName := args[0]
|
||||||
|
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||||
|
|
||||||
|
messages, err := ctx.Store.Messages(conversation)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
|
||||||
|
}
|
||||||
|
|
||||||
|
// walk backwards through the conversation and delete messages, break
|
||||||
|
// when we find the latest user response
|
||||||
|
for i := len(messages) - 1; i >= 0; i-- {
|
||||||
|
if messages[i].Role == model.MessageRoleUser {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
err = ctx.Store.DeleteMessage(&messages[i])
|
||||||
|
if err != nil {
|
||||||
|
lmcli.Warn("Could not delete previous reply: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cmdutil.HandleConversationReply(ctx, conversation, true)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||||
|
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||||
|
if len(args) != 0 {
|
||||||
|
return nil, compMode
|
||||||
|
}
|
||||||
|
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return cmd
|
||||||
|
}
|
284
pkg/cmd/util/util.go
Normal file
284
pkg/cmd/util/util.go
Normal file
@ -0,0 +1,284 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||||
|
"github.com/alecthomas/chroma/v2/quick"
|
||||||
|
"github.com/charmbracelet/lipgloss"
|
||||||
|
)
|
||||||
|
|
||||||
|
// fetchAndShowCompletion prompts the LLM with the given messages and streams
|
||||||
|
// the response to stdout. Returns all model reply messages.
|
||||||
|
func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]model.Message, error) {
|
||||||
|
content := make(chan string) // receives the reponse from LLM
|
||||||
|
defer close(content)
|
||||||
|
|
||||||
|
// render all content received over the channel
|
||||||
|
go ShowDelayedContent(content)
|
||||||
|
|
||||||
|
completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Defaults.Model)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var toolBag []model.Tool
|
||||||
|
for _, toolName := range *ctx.Config.Tools.EnabledTools {
|
||||||
|
tool, ok := tools.AvailableTools[toolName]
|
||||||
|
if ok {
|
||||||
|
toolBag = append(toolBag, tool)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
requestParams := model.RequestParameters{
|
||||||
|
Model: *ctx.Config.Defaults.Model,
|
||||||
|
MaxTokens: *ctx.Config.Defaults.MaxTokens,
|
||||||
|
Temperature: *ctx.Config.Defaults.Temperature,
|
||||||
|
ToolBag: toolBag,
|
||||||
|
}
|
||||||
|
|
||||||
|
var apiReplies []model.Message
|
||||||
|
response, err := completionProvider.CreateChatCompletionStream(
|
||||||
|
requestParams, messages, &apiReplies, content,
|
||||||
|
)
|
||||||
|
if response != "" {
|
||||||
|
// there was some content, so break to a new line after it
|
||||||
|
fmt.Println()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
lmcli.Warn("Received partial response. Error: %v\n", err)
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return apiReplies, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// lookupConversation either returns the conversation found by the
|
||||||
|
// short name or exits the program
|
||||||
|
func LookupConversation(ctx *lmcli.Context, shortName string) *model.Conversation {
|
||||||
|
c, err := ctx.Store.ConversationByShortName(shortName)
|
||||||
|
if err != nil {
|
||||||
|
lmcli.Fatal("Could not lookup conversation: %v\n", err)
|
||||||
|
}
|
||||||
|
if c.ID == 0 {
|
||||||
|
lmcli.Fatal("Conversation not found with short name: %s\n", shortName)
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func LookupConversationE(ctx *lmcli.Context, shortName string) (*model.Conversation, error) {
|
||||||
|
c, err := ctx.Store.ConversationByShortName(shortName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Could not lookup conversation: %v", err)
|
||||||
|
}
|
||||||
|
if c.ID == 0 {
|
||||||
|
return nil, fmt.Errorf("Conversation not found with short name: %s", shortName)
|
||||||
|
}
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleConversationReply handles sending messages to an existing
|
||||||
|
// conversation, optionally persisting both the sent replies and responses.
|
||||||
|
func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist bool, toSend ...model.Message) {
|
||||||
|
existing, err := ctx.Store.Messages(c)
|
||||||
|
if err != nil {
|
||||||
|
lmcli.Fatal("Could not retrieve messages for conversation: %s\n", c.Title)
|
||||||
|
}
|
||||||
|
|
||||||
|
if persist {
|
||||||
|
for _, message := range toSend {
|
||||||
|
err = ctx.Store.SaveMessage(&message)
|
||||||
|
if err != nil {
|
||||||
|
lmcli.Warn("Could not save %s message: %v\n", message.Role, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
allMessages := append(existing, toSend...)
|
||||||
|
|
||||||
|
RenderConversation(ctx, allMessages, true)
|
||||||
|
|
||||||
|
// render a message header with no contents
|
||||||
|
RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant}))
|
||||||
|
|
||||||
|
replies, err := FetchAndShowCompletion(ctx, allMessages)
|
||||||
|
if err != nil {
|
||||||
|
lmcli.Fatal("Error fetching LLM response: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if persist {
|
||||||
|
for _, reply := range replies {
|
||||||
|
reply.ConversationID = c.ID
|
||||||
|
|
||||||
|
err = ctx.Store.SaveMessage(&reply)
|
||||||
|
if err != nil {
|
||||||
|
lmcli.Warn("Could not save reply: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func FormatForExternalPrompt(messages []model.Message, system bool) string {
|
||||||
|
sb := strings.Builder{}
|
||||||
|
for _, message := range messages {
|
||||||
|
if message.Role != model.MessageRoleUser && (message.Role != model.MessageRoleSystem || !system) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
sb.WriteString(fmt.Sprintf("<%s>\n", message.Role.FriendlyRole()))
|
||||||
|
sb.WriteString(fmt.Sprintf("\"\"\"\n%s\n\"\"\"\n\n", message.Content))
|
||||||
|
}
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func GenerateTitle(ctx *lmcli.Context, c *model.Conversation) (string, error) {
|
||||||
|
messages, err := ctx.Store.Messages(c)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
const header = "Generate a concise 4-5 word title for the conversation below."
|
||||||
|
prompt := fmt.Sprintf("%s\n\n---\n\n%s", header, FormatForExternalPrompt(messages, false))
|
||||||
|
|
||||||
|
generateRequest := []model.Message{
|
||||||
|
{
|
||||||
|
Role: model.MessageRoleUser,
|
||||||
|
Content: prompt,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Conversations.TitleGenerationModel)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
requestParams := model.RequestParameters{
|
||||||
|
Model: *ctx.Config.Conversations.TitleGenerationModel,
|
||||||
|
MaxTokens: 25,
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err := completionProvider.CreateChatCompletion(requestParams, generateRequest, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShowWaitAnimation prints an animated ellipses to stdout until something is
|
||||||
|
// received on the signal channel. An empty string sent to the channel to
|
||||||
|
// noftify the caller that the animation has completed (carriage returned).
|
||||||
|
func ShowWaitAnimation(signal chan any) {
|
||||||
|
// Save the current cursor position
|
||||||
|
fmt.Print("\033[s")
|
||||||
|
|
||||||
|
animationStep := 0
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case _ = <-signal:
|
||||||
|
// Relmcli the cursor position
|
||||||
|
fmt.Print("\033[u")
|
||||||
|
signal <- ""
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
// Move the cursor to the saved position
|
||||||
|
modSix := animationStep % 6
|
||||||
|
if modSix == 3 || modSix == 0 {
|
||||||
|
fmt.Print("\033[u")
|
||||||
|
}
|
||||||
|
if modSix < 3 {
|
||||||
|
fmt.Print(".")
|
||||||
|
} else {
|
||||||
|
fmt.Print(" ")
|
||||||
|
}
|
||||||
|
animationStep++
|
||||||
|
time.Sleep(250 * time.Millisecond)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShowDelayedContent displays a waiting animation to stdout while waiting
|
||||||
|
// for content to be received on the provided channel. As soon as any (possibly
|
||||||
|
// chunked) content is received on the channel, the waiting animation is
|
||||||
|
// replaced by the content.
|
||||||
|
// Blocks until the channel is closed.
|
||||||
|
func ShowDelayedContent(content <-chan string) {
|
||||||
|
waitSignal := make(chan any)
|
||||||
|
go ShowWaitAnimation(waitSignal)
|
||||||
|
|
||||||
|
firstChunk := true
|
||||||
|
for chunk := range content {
|
||||||
|
if firstChunk {
|
||||||
|
// notify wait animation that we've received data
|
||||||
|
waitSignal <- ""
|
||||||
|
// wait for signal that wait animation has completed
|
||||||
|
<-waitSignal
|
||||||
|
firstChunk = false
|
||||||
|
}
|
||||||
|
fmt.Print(chunk)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RenderConversation renders the given messages to TTY, with optional space
|
||||||
|
// for a subsequent message. spaceForResponse controls how many '\n' characters
|
||||||
|
// are printed immediately after the final message (1 if false, 2 if true)
|
||||||
|
func RenderConversation(ctx *lmcli.Context, messages []model.Message, spaceForResponse bool) {
|
||||||
|
l := len(messages)
|
||||||
|
for i, message := range messages {
|
||||||
|
RenderMessage(ctx, &message)
|
||||||
|
if i < l-1 || spaceForResponse {
|
||||||
|
// print an additional space before the next message
|
||||||
|
fmt.Println()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HighlightMarkdown applies syntax highlighting to the provided markdown text
|
||||||
|
// and writes it to stdout.
|
||||||
|
func HighlightMarkdown(w io.Writer, markdownText string, formatter string, style string) error {
|
||||||
|
return quick.Highlight(w, markdownText, "md", formatter, style)
|
||||||
|
}
|
||||||
|
|
||||||
|
func RenderMessage(ctx *lmcli.Context, m *model.Message) {
|
||||||
|
var messageAge string
|
||||||
|
if m.CreatedAt.IsZero() {
|
||||||
|
messageAge = "now"
|
||||||
|
} else {
|
||||||
|
now := time.Now()
|
||||||
|
messageAge = util.HumanTimeElapsedSince(now.Sub(m.CreatedAt))
|
||||||
|
}
|
||||||
|
|
||||||
|
headerStyle := lipgloss.NewStyle().Bold(true)
|
||||||
|
|
||||||
|
switch m.Role {
|
||||||
|
case model.MessageRoleSystem:
|
||||||
|
headerStyle = headerStyle.Foreground(lipgloss.Color("9")) // bright red
|
||||||
|
case model.MessageRoleUser:
|
||||||
|
headerStyle = headerStyle.Foreground(lipgloss.Color("10")) // bright green
|
||||||
|
case model.MessageRoleAssistant:
|
||||||
|
headerStyle = headerStyle.Foreground(lipgloss.Color("12")) // bright blue
|
||||||
|
}
|
||||||
|
|
||||||
|
role := headerStyle.Render(m.Role.FriendlyRole())
|
||||||
|
|
||||||
|
separatorStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("3"))
|
||||||
|
separator := separatorStyle.Render("===")
|
||||||
|
timestamp := separatorStyle.Render(messageAge)
|
||||||
|
|
||||||
|
fmt.Printf("%s %s - %s %s\n\n", separator, role, timestamp, separator)
|
||||||
|
if m.Content != "" {
|
||||||
|
HighlightMarkdown(
|
||||||
|
os.Stdout, m.Content,
|
||||||
|
*ctx.Config.Chroma.Formatter,
|
||||||
|
*ctx.Config.Chroma.Style,
|
||||||
|
)
|
||||||
|
fmt.Println()
|
||||||
|
}
|
||||||
|
}
|
45
pkg/cmd/view.go
Normal file
45
pkg/cmd/view.go
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ViewCmd(ctx *lmcli.Context) *cobra.Command {
|
||||||
|
cmd := &cobra.Command{
|
||||||
|
Use: "view <conversation>",
|
||||||
|
Short: "View messages in a conversation",
|
||||||
|
Long: `Finds a conversation by its short name and displays its contents.`,
|
||||||
|
Args: func(cmd *cobra.Command, args []string) error {
|
||||||
|
argCount := 1
|
||||||
|
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
shortName := args[0]
|
||||||
|
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||||
|
|
||||||
|
messages, err := ctx.Store.Messages(conversation)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmdutil.RenderConversation(ctx, messages, false)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||||
|
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||||
|
if len(args) != 0 {
|
||||||
|
return nil, compMode
|
||||||
|
}
|
||||||
|
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return cmd
|
||||||
|
}
|
@ -1,46 +1,41 @@
|
|||||||
package cli
|
package lmcli
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||||
"github.com/go-yaml/yaml"
|
"github.com/go-yaml/yaml"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
ModelDefaults *struct {
|
Defaults *struct {
|
||||||
SystemPrompt *string `yaml:"systemPrompt" default:"You are a helpful assistant."`
|
SystemPrompt *string `yaml:"systemPrompt" default:"You are a helpful assistant."`
|
||||||
} `yaml:"modelDefaults"`
|
MaxTokens *int `yaml:"maxTokens" default:"256"`
|
||||||
|
Temperature *float32 `yaml:"temperature" default:"0.7"`
|
||||||
|
Model *string `yaml:"model" default:"gpt-4"`
|
||||||
|
} `yaml:"defaults"`
|
||||||
|
Conversations *struct {
|
||||||
|
TitleGenerationModel *string `yaml:"titleGenerationModel" default:"gpt-3.5-turbo"`
|
||||||
|
} `yaml:"conversations"`
|
||||||
|
Tools *struct {
|
||||||
|
EnabledTools *[]string `yaml:"enabledTools"`
|
||||||
|
} `yaml:"tools"`
|
||||||
OpenAI *struct {
|
OpenAI *struct {
|
||||||
APIKey *string `yaml:"apiKey" default:"your_key_here"`
|
APIKey *string `yaml:"apiKey" default:"your_key_here"`
|
||||||
DefaultModel *string `yaml:"defaultModel" default:"gpt-4"`
|
Models *[]string `yaml:"models"`
|
||||||
DefaultMaxLength *int `yaml:"defaultMaxLength" default:"256"`
|
|
||||||
EnabledTools []string `yaml:"enabledTools"`
|
|
||||||
} `yaml:"openai"`
|
} `yaml:"openai"`
|
||||||
|
Anthropic *struct {
|
||||||
|
APIKey *string `yaml:"apiKey" default:"your_key_here"`
|
||||||
|
Models *[]string `yaml:"models"`
|
||||||
|
} `yaml:"anthropic"`
|
||||||
Chroma *struct {
|
Chroma *struct {
|
||||||
Style *string `yaml:"style" default:"onedark"`
|
Style *string `yaml:"style" default:"onedark"`
|
||||||
Formatter *string `yaml:"formatter" default:"terminal16m"`
|
Formatter *string `yaml:"formatter" default:"terminal16m"`
|
||||||
} `yaml:"chroma"`
|
} `yaml:"chroma"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func configDir() string {
|
func NewConfig(configFile string) (*Config, error) {
|
||||||
var configDir string
|
|
||||||
|
|
||||||
xdgConfigHome := os.Getenv("XDG_CONFIG_HOME")
|
|
||||||
if xdgConfigHome != "" {
|
|
||||||
configDir = filepath.Join(xdgConfigHome, "lmcli")
|
|
||||||
} else {
|
|
||||||
userHomeDir, _ := os.UserHomeDir()
|
|
||||||
configDir = filepath.Join(userHomeDir, ".config/lmcli")
|
|
||||||
}
|
|
||||||
|
|
||||||
os.MkdirAll(configDir, 0755)
|
|
||||||
return configDir
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewConfig() (*Config, error) {
|
|
||||||
configFile := filepath.Join(configDir(), "config.yaml")
|
|
||||||
shouldWriteDefaults := false
|
shouldWriteDefaults := false
|
||||||
c := &Config{}
|
c := &Config{}
|
||||||
|
|
||||||
@ -54,11 +49,11 @@ func NewConfig() (*Config, error) {
|
|||||||
yaml.Unmarshal(configBytes, c)
|
yaml.Unmarshal(configBytes, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
shouldWriteDefaults = SetStructDefaults(c)
|
shouldWriteDefaults = util.SetStructDefaults(c)
|
||||||
if !configExists || shouldWriteDefaults {
|
if !configExists || shouldWriteDefaults {
|
||||||
if configExists {
|
if configExists {
|
||||||
fmt.Printf("Saving new defaults to configuration, backing up existing configuration to %s\n", configFile + ".bak")
|
fmt.Printf("Saving new defaults to configuration, backing up existing configuration to %s\n", configFile+".bak")
|
||||||
os.Rename(configFile, configFile + ".bak")
|
os.Rename(configFile, configFile+".bak")
|
||||||
}
|
}
|
||||||
fmt.Printf("Writing configuration file to %s\n", configFile)
|
fmt.Printf("Writing configuration file to %s\n", configFile)
|
||||||
file, err := os.Create(configFile)
|
file, err := os.Create(configFile)
|
97
pkg/lmcli/lmcli.go
Normal file
97
pkg/lmcli/lmcli.go
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
package lmcli
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Context struct {
|
||||||
|
Config Config
|
||||||
|
Store ConversationStore
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewContext() (*Context, error) {
|
||||||
|
configFile := filepath.Join(configDir(), "config.yaml")
|
||||||
|
config, err := NewConfig(configFile)
|
||||||
|
if err != nil {
|
||||||
|
Fatal("%v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
databaseFile := filepath.Join(dataDir(), "conversations.db")
|
||||||
|
db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Error establishing connection to store: %v", err)
|
||||||
|
}
|
||||||
|
s, err := NewSQLStore(db)
|
||||||
|
if err != nil {
|
||||||
|
Fatal("%v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Context{*config, s}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionClient, error) {
|
||||||
|
for _, m := range *c.Config.Anthropic.Models {
|
||||||
|
if m == model {
|
||||||
|
anthropic := &anthropic.AnthropicClient{
|
||||||
|
APIKey: *c.Config.Anthropic.APIKey,
|
||||||
|
}
|
||||||
|
return anthropic, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, m := range *c.Config.OpenAI.Models {
|
||||||
|
if m == model {
|
||||||
|
openai := &openai.OpenAIClient{
|
||||||
|
APIKey: *c.Config.OpenAI.APIKey,
|
||||||
|
}
|
||||||
|
return openai, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("unknown model: %s", model)
|
||||||
|
}
|
||||||
|
|
||||||
|
func configDir() string {
|
||||||
|
var configDir string
|
||||||
|
|
||||||
|
xdgConfigHome := os.Getenv("XDG_CONFIG_HOME")
|
||||||
|
if xdgConfigHome != "" {
|
||||||
|
configDir = filepath.Join(xdgConfigHome, "lmcli")
|
||||||
|
} else {
|
||||||
|
userHomeDir, _ := os.UserHomeDir()
|
||||||
|
configDir = filepath.Join(userHomeDir, ".config/lmcli")
|
||||||
|
}
|
||||||
|
|
||||||
|
os.MkdirAll(configDir, 0755)
|
||||||
|
return configDir
|
||||||
|
}
|
||||||
|
|
||||||
|
func dataDir() string {
|
||||||
|
var dataDir string
|
||||||
|
|
||||||
|
xdgDataHome := os.Getenv("XDG_DATA_HOME")
|
||||||
|
if xdgDataHome != "" {
|
||||||
|
dataDir = filepath.Join(xdgDataHome, "lmcli")
|
||||||
|
} else {
|
||||||
|
userHomeDir, _ := os.UserHomeDir()
|
||||||
|
dataDir = filepath.Join(userHomeDir, ".local/share/lmcli")
|
||||||
|
}
|
||||||
|
|
||||||
|
os.MkdirAll(dataDir, 0755)
|
||||||
|
return dataDir
|
||||||
|
}
|
||||||
|
|
||||||
|
func Fatal(format string, args ...any) {
|
||||||
|
fmt.Fprintf(os.Stderr, format, args...)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Warn(format string, args ...any) {
|
||||||
|
fmt.Fprintf(os.Stderr, format, args...)
|
||||||
|
}
|
58
pkg/lmcli/model/conversation.go
Normal file
58
pkg/lmcli/model/conversation.go
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MessageRole string
|
||||||
|
|
||||||
|
const (
|
||||||
|
MessageRoleSystem MessageRole = "system"
|
||||||
|
MessageRoleUser MessageRole = "user"
|
||||||
|
MessageRoleAssistant MessageRole = "assistant"
|
||||||
|
MessageRoleToolCall MessageRole = "tool_call"
|
||||||
|
MessageRoleToolResult MessageRole = "tool_result"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Message struct {
|
||||||
|
ID uint `gorm:"primaryKey"`
|
||||||
|
ConversationID uint `gorm:"foreignKey:ConversationID"`
|
||||||
|
Content string
|
||||||
|
Role MessageRole
|
||||||
|
CreatedAt time.Time
|
||||||
|
ToolCalls ToolCalls // a json array of tool calls (from the modl)
|
||||||
|
ToolResults ToolResults // a json array of tool results
|
||||||
|
}
|
||||||
|
|
||||||
|
type Conversation struct {
|
||||||
|
ID uint `gorm:"primaryKey"`
|
||||||
|
ShortName sql.NullString
|
||||||
|
Title string
|
||||||
|
}
|
||||||
|
|
||||||
|
type RequestParameters struct {
|
||||||
|
Model string
|
||||||
|
MaxTokens int
|
||||||
|
Temperature float32
|
||||||
|
TopP float32
|
||||||
|
|
||||||
|
SystemPrompt string
|
||||||
|
ToolBag []Tool
|
||||||
|
}
|
||||||
|
|
||||||
|
// FriendlyRole returns a human friendly signifier for the message's role.
|
||||||
|
func (m *MessageRole) FriendlyRole() string {
|
||||||
|
var friendlyRole string
|
||||||
|
switch *m {
|
||||||
|
case MessageRoleUser:
|
||||||
|
friendlyRole = "You"
|
||||||
|
case MessageRoleSystem:
|
||||||
|
friendlyRole = "System"
|
||||||
|
case MessageRoleAssistant:
|
||||||
|
friendlyRole = "Assistant"
|
||||||
|
default:
|
||||||
|
friendlyRole = string(*m)
|
||||||
|
}
|
||||||
|
return friendlyRole
|
||||||
|
}
|
98
pkg/lmcli/model/tool.go
Normal file
98
pkg/lmcli/model/tool.go
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Tool struct {
|
||||||
|
Name string
|
||||||
|
Description string
|
||||||
|
Parameters []ToolParameter
|
||||||
|
Impl func(*Tool, map[string]interface{}) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolParameter struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Type string `json:"type"` // "string", "integer", "boolean"
|
||||||
|
Required bool `json:"required"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Enum []string `json:"enum,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolCall struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Parameters map[string]interface{} `json:"parameters"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolCalls []ToolCall
|
||||||
|
|
||||||
|
func (tc *ToolCalls) Scan(value any) (err error) {
|
||||||
|
s := value.(string)
|
||||||
|
if value == nil || s == "" {
|
||||||
|
*tc = nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = json.Unmarshal([]byte(s), tc)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tc ToolCalls) Value() (driver.Value, error) {
|
||||||
|
if len(tc) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
jsonBytes, err := json.Marshal(tc)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("Could not marshal ToolCalls to JSON: %v\n", err)
|
||||||
|
}
|
||||||
|
return string(jsonBytes), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolResult struct {
|
||||||
|
ToolCallID string `json:"toolCallID"`
|
||||||
|
ToolName string `json:"toolName,omitempty"`
|
||||||
|
Result string `json:"result,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolResults []ToolResult
|
||||||
|
|
||||||
|
func (tr *ToolResults) Scan(value any) (err error) {
|
||||||
|
s := value.(string)
|
||||||
|
if value == nil || s == "" {
|
||||||
|
*tr = nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = json.Unmarshal([]byte(s), tr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tr ToolResults) Value() (driver.Value, error) {
|
||||||
|
if len(tr) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
jsonBytes, err := json.Marshal([]ToolResult(tr))
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("Could not marshal ToolResults to JSON: %v\n", err)
|
||||||
|
}
|
||||||
|
return string(jsonBytes), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type CallResult struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
Result any `json:"result,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r CallResult) ToJson() (string, error) {
|
||||||
|
if r.Message == "" {
|
||||||
|
// When message not supplied, assume success
|
||||||
|
r.Message = "success"
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonBytes, err := json.Marshal(r)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("Could not marshal CallResult to JSON: %v\n", err)
|
||||||
|
}
|
||||||
|
return string(jsonBytes), nil
|
||||||
|
}
|
322
pkg/lmcli/provider/anthropic/anthropic.go
Normal file
322
pkg/lmcli/provider/anthropic/anthropic.go
Normal file
@ -0,0 +1,322 @@
|
|||||||
|
package anthropic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"encoding/xml"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AnthropicClient struct {
|
||||||
|
APIKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Message struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
OriginalContent string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Request struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Messages []Message `json:"messages"`
|
||||||
|
System string `json:"system,omitempty"`
|
||||||
|
MaxTokens int `json:"max_tokens,omitempty"`
|
||||||
|
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||||
|
Stream bool `json:"stream,omitempty"`
|
||||||
|
Temperature float32 `json:"temperature,omitempty"`
|
||||||
|
//TopP float32 `json:"top_p,omitempty"`
|
||||||
|
//TopK float32 `json:"top_k,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OriginalContent struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Text string `json:"text"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Response struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Role string `json:"role"`
|
||||||
|
OriginalContent []OriginalContent `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
const FUNCTION_STOP_SEQUENCE = "</function_calls>"
|
||||||
|
|
||||||
|
func buildRequest(params model.RequestParameters, messages []model.Message) Request {
|
||||||
|
requestBody := Request{
|
||||||
|
Model: params.Model,
|
||||||
|
Messages: make([]Message, len(messages)),
|
||||||
|
System: params.SystemPrompt,
|
||||||
|
MaxTokens: params.MaxTokens,
|
||||||
|
Temperature: params.Temperature,
|
||||||
|
Stream: false,
|
||||||
|
|
||||||
|
StopSequences: []string{
|
||||||
|
FUNCTION_STOP_SEQUENCE,
|
||||||
|
"\n\nHuman:",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
startIdx := 0
|
||||||
|
if messages[0].Role == model.MessageRoleSystem {
|
||||||
|
requestBody.System = messages[0].Content
|
||||||
|
requestBody.Messages = requestBody.Messages[:len(messages)-1]
|
||||||
|
startIdx = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(params.ToolBag) > 0 {
|
||||||
|
if len(requestBody.System) > 0 {
|
||||||
|
// add a divider between existing system prompt and tools
|
||||||
|
requestBody.System += "\n\n---\n\n"
|
||||||
|
}
|
||||||
|
requestBody.System += buildToolsSystemPrompt(params.ToolBag)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, msg := range messages[startIdx:] {
|
||||||
|
message := &requestBody.Messages[i]
|
||||||
|
|
||||||
|
switch msg.Role {
|
||||||
|
case model.MessageRoleToolCall:
|
||||||
|
message.Role = "assistant"
|
||||||
|
message.OriginalContent = msg.Content
|
||||||
|
//message.ToolCalls = convertToolCallToOpenAI(m.ToolCalls)
|
||||||
|
case model.MessageRoleToolResult:
|
||||||
|
xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults)
|
||||||
|
xmlString, err := xmlFuncResults.XMLString()
|
||||||
|
if err != nil {
|
||||||
|
panic("Could not serialize []ToolResult to XMLFunctionResults")
|
||||||
|
}
|
||||||
|
message.Role = "user"
|
||||||
|
message.OriginalContent = xmlString
|
||||||
|
default:
|
||||||
|
message.Role = string(msg.Role)
|
||||||
|
message.OriginalContent = msg.Content
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return requestBody
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendRequest(c *AnthropicClient, r Request) (*http.Response, error) {
|
||||||
|
url := "https://api.anthropic.com/v1/messages"
|
||||||
|
|
||||||
|
jsonBody, err := json.Marshal(r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal request body: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonBody))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create HTTP request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("x-api-key", c.APIKey)
|
||||||
|
req.Header.Set("anthropic-version", "2023-06-01")
|
||||||
|
req.Header.Set("content-type", "application/json")
|
||||||
|
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to send HTTP request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *AnthropicClient) CreateChatCompletion(
|
||||||
|
params model.RequestParameters,
|
||||||
|
messages []model.Message,
|
||||||
|
replies *[]model.Message,
|
||||||
|
) (string, error) {
|
||||||
|
request := buildRequest(params, messages)
|
||||||
|
|
||||||
|
resp, err := sendRequest(c, request)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
var response Response
|
||||||
|
err = json.NewDecoder(resp.Body).Decode(&response)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to decode response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sb := strings.Builder{}
|
||||||
|
for _, content := range response.OriginalContent {
|
||||||
|
var reply model.Message
|
||||||
|
switch content.Type {
|
||||||
|
case "text":
|
||||||
|
reply = model.Message{
|
||||||
|
Role: model.MessageRoleAssistant,
|
||||||
|
Content: content.Text,
|
||||||
|
}
|
||||||
|
sb.WriteString(reply.Content)
|
||||||
|
default:
|
||||||
|
return "", fmt.Errorf("unsupported message type: %s", content.Type)
|
||||||
|
}
|
||||||
|
*replies = append(*replies, reply)
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *AnthropicClient) CreateChatCompletionStream(
|
||||||
|
params model.RequestParameters,
|
||||||
|
messages []model.Message,
|
||||||
|
replies *[]model.Message,
|
||||||
|
output chan<- string,
|
||||||
|
) (string, error) {
|
||||||
|
request := buildRequest(params, messages)
|
||||||
|
request.Stream = true
|
||||||
|
|
||||||
|
resp, err := sendRequest(c, request)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
sb := strings.Builder{}
|
||||||
|
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
|
||||||
|
if len(line) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if line[0] == '{' {
|
||||||
|
var event map[string]interface{}
|
||||||
|
err := json.Unmarshal([]byte(line), &event)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to unmarshal event data '%s': %v", line, err)
|
||||||
|
}
|
||||||
|
eventType, ok := event["type"].(string)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("invalid event: %s", line)
|
||||||
|
}
|
||||||
|
switch eventType {
|
||||||
|
case "error":
|
||||||
|
return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
|
||||||
|
default:
|
||||||
|
return sb.String(), fmt.Errorf("unknown event type: %s", eventType)
|
||||||
|
}
|
||||||
|
} else if strings.HasPrefix(line, "data:") {
|
||||||
|
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||||
|
var event map[string]interface{}
|
||||||
|
err := json.Unmarshal([]byte(data), &event)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to unmarshal event data: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
eventType, ok := event["type"].(string)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("invalid event type")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch eventType {
|
||||||
|
case "message_start":
|
||||||
|
// noop
|
||||||
|
case "ping":
|
||||||
|
// write an empty string to signal start of text
|
||||||
|
output <- ""
|
||||||
|
case "content_block_start":
|
||||||
|
// ignore?
|
||||||
|
case "content_block_delta":
|
||||||
|
delta, ok := event["delta"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("invalid content block delta")
|
||||||
|
}
|
||||||
|
text, ok := delta["text"].(string)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("invalid text delta")
|
||||||
|
}
|
||||||
|
sb.WriteString(text)
|
||||||
|
output <- text
|
||||||
|
case "content_block_stop":
|
||||||
|
// ignore?
|
||||||
|
case "message_delta":
|
||||||
|
delta, ok := event["delta"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("invalid message delta")
|
||||||
|
}
|
||||||
|
stopReason, ok := delta["stop_reason"].(string)
|
||||||
|
if ok && stopReason == "stop_sequence" {
|
||||||
|
stopSequence, ok := delta["stop_sequence"].(string)
|
||||||
|
if ok && stopSequence == FUNCTION_STOP_SEQUENCE {
|
||||||
|
content := sb.String()
|
||||||
|
|
||||||
|
start := strings.Index(content, "<function_calls>")
|
||||||
|
if start == -1 {
|
||||||
|
return content, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found")
|
||||||
|
}
|
||||||
|
|
||||||
|
funcCallXml := content[start:]
|
||||||
|
funcCallXml += FUNCTION_STOP_SEQUENCE
|
||||||
|
|
||||||
|
sb.WriteString(FUNCTION_STOP_SEQUENCE)
|
||||||
|
output <- FUNCTION_STOP_SEQUENCE
|
||||||
|
|
||||||
|
// Extract function calls
|
||||||
|
var functionCalls XMLFunctionCalls
|
||||||
|
err := xml.Unmarshal([]byte(sb.String()), &functionCalls)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to unmarshal function_calls: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute function calls
|
||||||
|
toolCall := model.Message{
|
||||||
|
Role: model.MessageRoleToolCall,
|
||||||
|
Content: content,
|
||||||
|
ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls),
|
||||||
|
}
|
||||||
|
|
||||||
|
toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
toolReply := model.Message{
|
||||||
|
Role: model.MessageRoleToolResult,
|
||||||
|
ToolResults: toolResults,
|
||||||
|
}
|
||||||
|
|
||||||
|
if replies != nil {
|
||||||
|
*replies = append(append(*replies, toolCall), toolReply)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recurse into CreateChatCompletionStream with the tool call replies
|
||||||
|
// added to the original messages
|
||||||
|
messages = append(append(messages, toolCall), toolReply)
|
||||||
|
return c.CreateChatCompletionStream(params, messages, replies, output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "message_stop":
|
||||||
|
// return the completed message
|
||||||
|
reply := model.Message{
|
||||||
|
Role: model.MessageRoleAssistant,
|
||||||
|
Content: sb.String(),
|
||||||
|
}
|
||||||
|
*replies = append(*replies, reply)
|
||||||
|
return sb.String(), nil
|
||||||
|
case "error":
|
||||||
|
return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
|
||||||
|
default:
|
||||||
|
fmt.Printf("\nUnrecognized event: %s\n", data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to read response body: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("unexpected end of stream")
|
||||||
|
}
|
182
pkg/lmcli/provider/anthropic/tools.go
Normal file
182
pkg/lmcli/provider/anthropic/tools.go
Normal file
@ -0,0 +1,182 @@
|
|||||||
|
package anthropic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"strings"
|
||||||
|
"text/template"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
const TOOL_PREAMBLE = `In this environment you have access to a set of tools which may assist you in fulfilling user requests.
|
||||||
|
|
||||||
|
You may call them like this:
|
||||||
|
<function_calls>
|
||||||
|
<invoke>
|
||||||
|
<tool_name>$TOOL_NAME</tool_name>
|
||||||
|
<parameters>
|
||||||
|
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
|
||||||
|
...
|
||||||
|
</parameters>
|
||||||
|
</invoke>
|
||||||
|
</function_calls>
|
||||||
|
|
||||||
|
Here are the tools available:`
|
||||||
|
|
||||||
|
type XMLTools struct {
|
||||||
|
XMLName struct{} `xml:"tools"`
|
||||||
|
ToolDescriptions []XMLToolDescription `xml:"tool_description"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type XMLToolDescription struct {
|
||||||
|
ToolName string `xml:"tool_name"`
|
||||||
|
Description string `xml:"description"`
|
||||||
|
Parameters []XMLToolParameter `xml:"parameters>parameter"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type XMLToolParameter struct {
|
||||||
|
Name string `xml:"name"`
|
||||||
|
Type string `xml:"type"`
|
||||||
|
Description string `xml:"description"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type XMLFunctionCalls struct {
|
||||||
|
XMLName struct{} `xml:"function_calls"`
|
||||||
|
Invoke []XMLFunctionInvoke `xml:"invoke"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type XMLFunctionInvoke struct {
|
||||||
|
ToolName string `xml:"tool_name"`
|
||||||
|
Parameters XMLFunctionInvokeParameters `xml:"parameters"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type XMLFunctionInvokeParameters struct {
|
||||||
|
String string `xml:",innerxml"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type XMLFunctionResults struct {
|
||||||
|
XMLName struct{} `xml:"function_results"`
|
||||||
|
Result []XMLFunctionResult `xml:"result"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type XMLFunctionResult struct {
|
||||||
|
ToolName string `xml:"tool_name"`
|
||||||
|
Stdout string `xml:"stdout"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// accepts raw XML from XMLFunctionInvokeParameters.String, returns map of
|
||||||
|
// parameters name to value
|
||||||
|
func parseFunctionParametersXML(params string) map[string]interface{} {
|
||||||
|
lines := strings.Split(params, "\n")
|
||||||
|
ret := make(map[string]interface{}, len(lines))
|
||||||
|
for _, line := range lines {
|
||||||
|
i := strings.Index(line, ">")
|
||||||
|
if i == -1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
j := strings.Index(line, "</")
|
||||||
|
if j == -1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// chop from after opening < to first > to get parameter name,
|
||||||
|
// then chop after > to first </ to get parameter value
|
||||||
|
ret[line[1:i]] = line[i+1 : j]
|
||||||
|
}
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertToolsToXMLTools(tools []model.Tool) XMLTools {
|
||||||
|
converted := make([]XMLToolDescription, len(tools))
|
||||||
|
for i, tool := range tools {
|
||||||
|
converted[i].ToolName = tool.Name
|
||||||
|
converted[i].Description = tool.Description
|
||||||
|
|
||||||
|
params := make([]XMLToolParameter, len(tool.Parameters))
|
||||||
|
for j, param := range tool.Parameters {
|
||||||
|
params[j].Name = param.Name
|
||||||
|
params[j].Description = param.Description
|
||||||
|
params[j].Type = param.Type
|
||||||
|
}
|
||||||
|
|
||||||
|
converted[i].Parameters = params
|
||||||
|
}
|
||||||
|
return XMLTools{
|
||||||
|
ToolDescriptions: converted,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []model.ToolCall {
|
||||||
|
toolCalls := make([]model.ToolCall, len(functionCalls.Invoke))
|
||||||
|
for i, invoke := range functionCalls.Invoke {
|
||||||
|
toolCalls[i].Name = invoke.ToolName
|
||||||
|
toolCalls[i].Parameters = parseFunctionParametersXML(invoke.Parameters.String)
|
||||||
|
}
|
||||||
|
return toolCalls
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertToolResultsToXMLFunctionResult(toolResults []model.ToolResult) XMLFunctionResults {
|
||||||
|
converted := make([]XMLFunctionResult, len(toolResults))
|
||||||
|
for i, result := range toolResults {
|
||||||
|
converted[i].ToolName = result.ToolName
|
||||||
|
converted[i].Stdout = result.Result
|
||||||
|
}
|
||||||
|
return XMLFunctionResults{
|
||||||
|
Result: converted,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildToolsSystemPrompt(tools []model.Tool) string {
|
||||||
|
xmlTools := convertToolsToXMLTools(tools)
|
||||||
|
xmlToolsString, err := xmlTools.XMLString()
|
||||||
|
if err != nil {
|
||||||
|
panic("Could not serialize []model.Tool to XMLTools")
|
||||||
|
}
|
||||||
|
return TOOL_PREAMBLE + "\n" + xmlToolsString + "\n"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x XMLTools) XMLString() (string, error) {
|
||||||
|
tmpl, err := template.New("tools").Parse(`<tools>
|
||||||
|
{{range .ToolDescriptions}}<tool_description>
|
||||||
|
<tool_name>{{.ToolName}}</tool_name>
|
||||||
|
<description>
|
||||||
|
{{.Description}}
|
||||||
|
</description>
|
||||||
|
<parameters>
|
||||||
|
{{range .Parameters}}<parameter>
|
||||||
|
<name>{{.Name}}</name>
|
||||||
|
<type>{{.Type}}</type>
|
||||||
|
<description>{{.Description}}</description>
|
||||||
|
</parameter>
|
||||||
|
{{end}}</parameters>
|
||||||
|
</tool_description>
|
||||||
|
{{end}}</tools>`)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := tmpl.Execute(&buf, x); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x XMLFunctionResults) XMLString() (string, error) {
|
||||||
|
tmpl, err := template.New("function_results").Parse(`<function_results>
|
||||||
|
{{range .Result}}<result>
|
||||||
|
<tool_name>{{.ToolName}}</tool_name>
|
||||||
|
<stdout>{{.Stdout}}</stdout>
|
||||||
|
</result>
|
||||||
|
{{end}}</function_results>`)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := tmpl.Execute(&buf, x); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
270
pkg/lmcli/provider/openai/openai.go
Normal file
270
pkg/lmcli/provider/openai/openai.go
Normal file
@ -0,0 +1,270 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
||||||
|
openai "github.com/sashabaranov/go-openai"
|
||||||
|
)
|
||||||
|
|
||||||
|
type OpenAIClient struct {
|
||||||
|
APIKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAIToolParameters struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Properties map[string]OpenAIToolParameter `json:"properties,omitempty"`
|
||||||
|
Required []string `json:"required,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAIToolParameter struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Enum []string `json:"enum,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertTools(tools []model.Tool) []openai.Tool {
|
||||||
|
openaiTools := make([]openai.Tool, len(tools))
|
||||||
|
for i, tool := range tools {
|
||||||
|
openaiTools[i].Type = "function"
|
||||||
|
|
||||||
|
params := make(map[string]OpenAIToolParameter)
|
||||||
|
var required []string
|
||||||
|
|
||||||
|
for _, param := range tool.Parameters {
|
||||||
|
params[param.Name] = OpenAIToolParameter{
|
||||||
|
Type: param.Type,
|
||||||
|
Description: param.Description,
|
||||||
|
Enum: param.Enum,
|
||||||
|
}
|
||||||
|
if param.Required {
|
||||||
|
required = append(required, param.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
openaiTools[i].Function = openai.FunctionDefinition{
|
||||||
|
Name: tool.Name,
|
||||||
|
Description: tool.Description,
|
||||||
|
Parameters: OpenAIToolParameters{
|
||||||
|
Type: "object",
|
||||||
|
Properties: params,
|
||||||
|
Required: required,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return openaiTools
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertToolCallToOpenAI(toolCalls []model.ToolCall) []openai.ToolCall {
|
||||||
|
converted := make([]openai.ToolCall, len(toolCalls))
|
||||||
|
for i, call := range toolCalls {
|
||||||
|
converted[i].Type = "function"
|
||||||
|
converted[i].ID = call.ID
|
||||||
|
converted[i].Function.Name = call.Name
|
||||||
|
|
||||||
|
json, _ := json.Marshal(call.Parameters)
|
||||||
|
converted[i].Function.Arguments = string(json)
|
||||||
|
}
|
||||||
|
return converted
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertToolCallToAPI(toolCalls []openai.ToolCall) []model.ToolCall {
|
||||||
|
converted := make([]model.ToolCall, len(toolCalls))
|
||||||
|
for i, call := range toolCalls {
|
||||||
|
converted[i].ID = call.ID
|
||||||
|
converted[i].Name = call.Function.Name
|
||||||
|
json.Unmarshal([]byte(call.Function.Arguments), &converted[i].Parameters)
|
||||||
|
}
|
||||||
|
return converted
|
||||||
|
}
|
||||||
|
|
||||||
|
func createChatCompletionRequest(
|
||||||
|
c *OpenAIClient,
|
||||||
|
params model.RequestParameters,
|
||||||
|
messages []model.Message,
|
||||||
|
) openai.ChatCompletionRequest {
|
||||||
|
requestMessages := make([]openai.ChatCompletionMessage, 0, len(messages))
|
||||||
|
|
||||||
|
for _, m := range messages {
|
||||||
|
switch m.Role {
|
||||||
|
case "tool_call":
|
||||||
|
message := openai.ChatCompletionMessage{}
|
||||||
|
message.Role = "assistant"
|
||||||
|
message.Content = m.Content
|
||||||
|
message.ToolCalls = convertToolCallToOpenAI(m.ToolCalls)
|
||||||
|
requestMessages = append(requestMessages, message)
|
||||||
|
case "tool_result":
|
||||||
|
// expand tool_result messages' results into multiple openAI messages
|
||||||
|
for _, result := range m.ToolResults {
|
||||||
|
message := openai.ChatCompletionMessage{}
|
||||||
|
message.Role = "tool"
|
||||||
|
message.Content = result.Result
|
||||||
|
message.ToolCallID = result.ToolCallID
|
||||||
|
requestMessages = append(requestMessages, message)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
message := openai.ChatCompletionMessage{}
|
||||||
|
message.Role = string(m.Role)
|
||||||
|
message.Content = m.Content
|
||||||
|
requestMessages = append(requestMessages, message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
request := openai.ChatCompletionRequest{
|
||||||
|
Model: params.Model,
|
||||||
|
MaxTokens: params.MaxTokens,
|
||||||
|
Temperature: params.Temperature,
|
||||||
|
Messages: requestMessages,
|
||||||
|
N: 1, // limit responses to 1 "choice". we use choices[0] to reference it
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(params.ToolBag) > 0 {
|
||||||
|
request.Tools = convertTools(params.ToolBag)
|
||||||
|
request.ToolChoice = "auto"
|
||||||
|
}
|
||||||
|
|
||||||
|
return request
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleToolCalls(
|
||||||
|
params model.RequestParameters,
|
||||||
|
content string,
|
||||||
|
toolCalls []openai.ToolCall,
|
||||||
|
) ([]model.Message, error) {
|
||||||
|
toolCall := model.Message{
|
||||||
|
Role: model.MessageRoleToolCall,
|
||||||
|
Content: content,
|
||||||
|
ToolCalls: convertToolCallToAPI(toolCalls),
|
||||||
|
}
|
||||||
|
|
||||||
|
toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
toolResult := model.Message{
|
||||||
|
Role: model.MessageRoleToolResult,
|
||||||
|
ToolResults: toolResults,
|
||||||
|
}
|
||||||
|
|
||||||
|
return []model.Message{toolCall, toolResult}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *OpenAIClient) CreateChatCompletion(
|
||||||
|
params model.RequestParameters,
|
||||||
|
messages []model.Message,
|
||||||
|
replies *[]model.Message,
|
||||||
|
) (string, error) {
|
||||||
|
client := openai.NewClient(c.APIKey)
|
||||||
|
req := createChatCompletionRequest(c, params, messages)
|
||||||
|
resp, err := client.CreateChatCompletion(context.Background(), req)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
choice := resp.Choices[0]
|
||||||
|
|
||||||
|
toolCalls := choice.Message.ToolCalls
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
results, err := handleToolCalls(params, choice.Message.Content, toolCalls)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if results != nil {
|
||||||
|
*replies = append(*replies, results...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recurse into CreateChatCompletion with the tool call replies
|
||||||
|
messages = append(messages, results...)
|
||||||
|
return c.CreateChatCompletion(params, messages, replies)
|
||||||
|
}
|
||||||
|
|
||||||
|
if replies != nil {
|
||||||
|
*replies = append(*replies, model.Message{
|
||||||
|
Role: model.MessageRoleAssistant,
|
||||||
|
Content: choice.Message.Content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the user-facing message.
|
||||||
|
return choice.Message.Content, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *OpenAIClient) CreateChatCompletionStream(
|
||||||
|
params model.RequestParameters,
|
||||||
|
messages []model.Message,
|
||||||
|
replies *[]model.Message,
|
||||||
|
output chan<- string,
|
||||||
|
) (string, error) {
|
||||||
|
client := openai.NewClient(c.APIKey)
|
||||||
|
req := createChatCompletionRequest(c, params, messages)
|
||||||
|
|
||||||
|
stream, err := client.CreateChatCompletionStream(context.Background(), req)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer stream.Close()
|
||||||
|
|
||||||
|
content := strings.Builder{}
|
||||||
|
toolCalls := []openai.ToolCall{}
|
||||||
|
|
||||||
|
// Iterate stream segments
|
||||||
|
for {
|
||||||
|
response, e := stream.Recv()
|
||||||
|
if errors.Is(e, io.EOF) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if e != nil {
|
||||||
|
err = e
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
delta := response.Choices[0].Delta
|
||||||
|
if len(delta.ToolCalls) > 0 {
|
||||||
|
// Construct streamed tool_call arguments
|
||||||
|
for _, tc := range delta.ToolCalls {
|
||||||
|
if tc.Index == nil {
|
||||||
|
return "", fmt.Errorf("Unexpected nil index for streamed tool call.")
|
||||||
|
}
|
||||||
|
if len(toolCalls) <= *tc.Index {
|
||||||
|
toolCalls = append(toolCalls, tc)
|
||||||
|
} else {
|
||||||
|
toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
output <- delta.Content
|
||||||
|
content.WriteString(delta.Content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
results, err := handleToolCalls(params, content.String(), toolCalls)
|
||||||
|
if err != nil {
|
||||||
|
return content.String(), err
|
||||||
|
}
|
||||||
|
if results != nil {
|
||||||
|
*replies = append(*replies, results...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recurse into CreateChatCompletionStream with the tool call replies
|
||||||
|
messages = append(messages, results...)
|
||||||
|
return c.CreateChatCompletionStream(params, messages, replies, output)
|
||||||
|
}
|
||||||
|
|
||||||
|
if replies != nil {
|
||||||
|
*replies = append(*replies, model.Message{
|
||||||
|
Role: model.MessageRoleAssistant,
|
||||||
|
Content: content.String(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return content.String(), err
|
||||||
|
}
|
23
pkg/lmcli/provider/provider.go
Normal file
23
pkg/lmcli/provider/provider.go
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
package provider
|
||||||
|
|
||||||
|
import "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||||
|
|
||||||
|
type ChatCompletionClient interface {
|
||||||
|
// CreateChatCompletion requests a response to the provided messages.
|
||||||
|
// Replies are appended to the given replies struct, and the
|
||||||
|
// complete user-facing response is returned as a string.
|
||||||
|
CreateChatCompletion(
|
||||||
|
params model.RequestParameters,
|
||||||
|
messages []model.Message,
|
||||||
|
replies *[]model.Message,
|
||||||
|
) (string, error)
|
||||||
|
|
||||||
|
// Like CreateChageCompletion, except the response is streamed via
|
||||||
|
// the output channel as it's received.
|
||||||
|
CreateChatCompletionStream(
|
||||||
|
params model.RequestParameters,
|
||||||
|
messages []model.Message,
|
||||||
|
replies *[]model.Message,
|
||||||
|
output chan<- string,
|
||||||
|
) (string, error)
|
||||||
|
}
|
121
pkg/lmcli/store.go
Normal file
121
pkg/lmcli/store.go
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
package lmcli
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||||
|
sqids "github.com/sqids/sqids-go"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ConversationStore interface {
|
||||||
|
Conversations() ([]model.Conversation, error)
|
||||||
|
|
||||||
|
ConversationByShortName(shortName string) (*model.Conversation, error)
|
||||||
|
ConversationShortNameCompletions(search string) []string
|
||||||
|
|
||||||
|
SaveConversation(conversation *model.Conversation) error
|
||||||
|
DeleteConversation(conversation *model.Conversation) error
|
||||||
|
|
||||||
|
Messages(conversation *model.Conversation) ([]model.Message, error)
|
||||||
|
LastMessage(conversation *model.Conversation) (*model.Message, error)
|
||||||
|
|
||||||
|
SaveMessage(message *model.Message) error
|
||||||
|
DeleteMessage(message *model.Message) error
|
||||||
|
UpdateMessage(message *model.Message) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type SQLStore struct {
|
||||||
|
db *gorm.DB
|
||||||
|
sqids *sqids.Sqids
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSQLStore(db *gorm.DB) (*SQLStore, error) {
|
||||||
|
models := []any{
|
||||||
|
&model.Conversation{},
|
||||||
|
&model.Message{},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, x := range models {
|
||||||
|
err := db.AutoMigrate(x)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Could not perform database migrations: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_sqids, _ := sqids.New(sqids.Options{MinLength: 4})
|
||||||
|
return &SQLStore{db, _sqids}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLStore) SaveConversation(conversation *model.Conversation) error {
|
||||||
|
err := s.db.Save(&conversation).Error
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !conversation.ShortName.Valid {
|
||||||
|
shortName, _ := s.sqids.Encode([]uint64{uint64(conversation.ID)})
|
||||||
|
conversation.ShortName = sql.NullString{String: shortName, Valid: true}
|
||||||
|
err = s.db.Save(&conversation).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLStore) DeleteConversation(conversation *model.Conversation) error {
|
||||||
|
s.db.Where("conversation_id = ?", conversation.ID).Delete(&model.Message{})
|
||||||
|
return s.db.Delete(&conversation).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLStore) SaveMessage(message *model.Message) error {
|
||||||
|
return s.db.Create(message).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLStore) DeleteMessage(message *model.Message) error {
|
||||||
|
return s.db.Delete(&message).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLStore) UpdateMessage(message *model.Message) error {
|
||||||
|
return s.db.Updates(&message).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLStore) Conversations() ([]model.Conversation, error) {
|
||||||
|
var conversations []model.Conversation
|
||||||
|
err := s.db.Find(&conversations).Error
|
||||||
|
return conversations, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
|
||||||
|
var completions []string
|
||||||
|
conversations, _ := s.Conversations() // ignore error for completions
|
||||||
|
for _, conversation := range conversations {
|
||||||
|
if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) {
|
||||||
|
completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return completions
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLStore) ConversationByShortName(shortName string) (*model.Conversation, error) {
|
||||||
|
if shortName == "" {
|
||||||
|
return nil, errors.New("shortName is empty")
|
||||||
|
}
|
||||||
|
var conversation model.Conversation
|
||||||
|
err := s.db.Where("short_name = ?", shortName).Find(&conversation).Error
|
||||||
|
return &conversation, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLStore) Messages(conversation *model.Conversation) ([]model.Message, error) {
|
||||||
|
var messages []model.Message
|
||||||
|
err := s.db.Where("conversation_id = ?", conversation.ID).Find(&messages).Error
|
||||||
|
return messages, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLStore) LastMessage(conversation *model.Conversation) (*model.Message, error) {
|
||||||
|
var message model.Message
|
||||||
|
err := s.db.Where("conversation_id = ?", conversation.ID).Last(&message).Error
|
||||||
|
return &message, err
|
||||||
|
}
|
114
pkg/lmcli/tools/file_insert_lines.go
Normal file
114
pkg/lmcli/tools/file_insert_lines.go
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
const FILE_INSERT_LINES_DESCRIPTION = `Insert lines into a file, must specify path.
|
||||||
|
|
||||||
|
Make sure your inserts match the flow and indentation of surrounding content.`
|
||||||
|
|
||||||
|
var FileInsertLinesTool = model.Tool{
|
||||||
|
Name: "file_insert_lines",
|
||||||
|
Description: FILE_INSERT_LINES_DESCRIPTION,
|
||||||
|
Parameters: []model.ToolParameter{
|
||||||
|
{
|
||||||
|
Name: "path",
|
||||||
|
Type: "string",
|
||||||
|
Description: "Path of the file to be modified, relative to the current working directory.",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "position",
|
||||||
|
Type: "integer",
|
||||||
|
Description: `Which line to insert content *before*.`,
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "content",
|
||||||
|
Type: "string",
|
||||||
|
Description: `The content to insert.`,
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
|
||||||
|
tmp, ok := args["path"]
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("path parameter to write_file was not included.")
|
||||||
|
}
|
||||||
|
path, ok := tmp.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||||
|
}
|
||||||
|
var position int
|
||||||
|
tmp, ok = args["position"]
|
||||||
|
if ok {
|
||||||
|
tmp, ok := tmp.(float64)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("Invalid position in function arguments: %v", tmp)
|
||||||
|
}
|
||||||
|
position = int(tmp)
|
||||||
|
}
|
||||||
|
var content string
|
||||||
|
tmp, ok = args["content"]
|
||||||
|
if ok {
|
||||||
|
content, ok = tmp.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result := fileInsertLines(path, position, content)
|
||||||
|
ret, err := result.ToJson()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("Could not serialize result: %v", err)
|
||||||
|
}
|
||||||
|
return ret, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileInsertLines(path string, position int, content string) model.CallResult {
|
||||||
|
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||||
|
if !ok {
|
||||||
|
return model.CallResult{Message: reason}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the existing file's content
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
if !os.IsNotExist(err) {
|
||||||
|
return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
|
||||||
|
}
|
||||||
|
_, err = os.Create(path)
|
||||||
|
if err != nil {
|
||||||
|
return model.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}
|
||||||
|
}
|
||||||
|
data = []byte{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if position < 1 {
|
||||||
|
return model.CallResult{Message: "start_line cannot be less than 1"}
|
||||||
|
}
|
||||||
|
|
||||||
|
lines := strings.Split(string(data), "\n")
|
||||||
|
contentLines := strings.Split(strings.Trim(content, "\n"), "\n")
|
||||||
|
|
||||||
|
before := lines[:position-1]
|
||||||
|
after := lines[position-1:]
|
||||||
|
lines = append(before, append(contentLines, after...)...)
|
||||||
|
|
||||||
|
newContent := strings.Join(lines, "\n")
|
||||||
|
|
||||||
|
// Join the lines and write back to the file
|
||||||
|
err = os.WriteFile(path, []byte(newContent), 0644)
|
||||||
|
if err != nil {
|
||||||
|
return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
|
||||||
|
}
|
||||||
|
|
||||||
|
return model.CallResult{Result: newContent}
|
||||||
|
}
|
133
pkg/lmcli/tools/file_replace_lines.go
Normal file
133
pkg/lmcli/tools/file_replace_lines.go
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||||
|
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
const FILE_REPLACE_LINES_DESCRIPTION = `Replace or remove a range of lines within a file, must specify path.
|
||||||
|
|
||||||
|
Useful for re-writing snippets/blocks of code or entire functions.
|
||||||
|
|
||||||
|
Plan your edits carefully and ensure any new content matches the flow and indentation of surrounding text.`
|
||||||
|
|
||||||
|
var FileReplaceLinesTool = model.Tool{
|
||||||
|
Name: "file_replace_lines",
|
||||||
|
Description: FILE_REPLACE_LINES_DESCRIPTION,
|
||||||
|
Parameters: []model.ToolParameter{
|
||||||
|
{
|
||||||
|
Name: "path",
|
||||||
|
Type: "string",
|
||||||
|
Description: "Path of the file to be modified, relative to the current working directory.",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "start_line",
|
||||||
|
Type: "integer",
|
||||||
|
Description: `Line number which specifies the start of the replacement range (inclusive).`,
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "end_line",
|
||||||
|
Type: "integer",
|
||||||
|
Description: `Line number which specifies the end of the replacement range (inclusive). If unset, range extends to end of file.`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "content",
|
||||||
|
Type: "string",
|
||||||
|
Description: `Content to replace specified range. Omit to remove the specified range.`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
|
||||||
|
tmp, ok := args["path"]
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("path parameter to write_file was not included.")
|
||||||
|
}
|
||||||
|
path, ok := tmp.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||||
|
}
|
||||||
|
var start_line int
|
||||||
|
tmp, ok = args["start_line"]
|
||||||
|
if ok {
|
||||||
|
tmp, ok := tmp.(float64)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("Invalid start_line in function arguments: %v", tmp)
|
||||||
|
}
|
||||||
|
start_line = int(tmp)
|
||||||
|
}
|
||||||
|
var end_line int
|
||||||
|
tmp, ok = args["end_line"]
|
||||||
|
if ok {
|
||||||
|
tmp, ok := tmp.(float64)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("Invalid end_line in function arguments: %v", tmp)
|
||||||
|
}
|
||||||
|
end_line = int(tmp)
|
||||||
|
}
|
||||||
|
var content string
|
||||||
|
tmp, ok = args["content"]
|
||||||
|
if ok {
|
||||||
|
content, ok = tmp.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result := fileReplaceLines(path, start_line, end_line, content)
|
||||||
|
ret, err := result.ToJson()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("Could not serialize result: %v", err)
|
||||||
|
}
|
||||||
|
return ret, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileReplaceLines(path string, startLine int, endLine int, content string) model.CallResult {
|
||||||
|
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||||
|
if !ok {
|
||||||
|
return model.CallResult{Message: reason}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the existing file's content
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
if !os.IsNotExist(err) {
|
||||||
|
return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
|
||||||
|
}
|
||||||
|
_, err = os.Create(path)
|
||||||
|
if err != nil {
|
||||||
|
return model.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}
|
||||||
|
}
|
||||||
|
data = []byte{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if startLine < 1 {
|
||||||
|
return model.CallResult{Message: "start_line cannot be less than 1"}
|
||||||
|
}
|
||||||
|
|
||||||
|
lines := strings.Split(string(data), "\n")
|
||||||
|
contentLines := strings.Split(strings.Trim(content, "\n"), "\n")
|
||||||
|
|
||||||
|
if endLine == 0 || endLine > len(lines) {
|
||||||
|
endLine = len(lines)
|
||||||
|
}
|
||||||
|
|
||||||
|
before := lines[:startLine-1]
|
||||||
|
after := lines[endLine:]
|
||||||
|
|
||||||
|
lines = append(before, append(contentLines, after...)...)
|
||||||
|
newContent := strings.Join(lines, "\n")
|
||||||
|
|
||||||
|
// Join the lines and write back to the file
|
||||||
|
err = os.WriteFile(path, []byte(newContent), 0644)
|
||||||
|
if err != nil {
|
||||||
|
return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
|
||||||
|
}
|
||||||
|
|
||||||
|
return model.CallResult{Result: newContent}
|
||||||
|
}
|
100
pkg/lmcli/tools/read_dir.go
Normal file
100
pkg/lmcli/tools/read_dir.go
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||||
|
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
const READ_DIR_DESCRIPTION = `Return the contents of the CWD (current working directory).
|
||||||
|
|
||||||
|
Example result:
|
||||||
|
{
|
||||||
|
"message": "success",
|
||||||
|
"result": [
|
||||||
|
{"name": "a_file.txt", "type": "file", "size": 123},
|
||||||
|
{"name": "a_directory/", "type": "dir", "size": 11},
|
||||||
|
...
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
For files, size represents the size of the file, in bytes.
|
||||||
|
For directories, size represents the number of entries in that directory.`
|
||||||
|
|
||||||
|
var ReadDirTool = model.Tool{
|
||||||
|
Name: "read_dir",
|
||||||
|
Description: READ_DIR_DESCRIPTION,
|
||||||
|
Parameters: []model.ToolParameter{
|
||||||
|
{
|
||||||
|
Name: "relative_dir",
|
||||||
|
Type: "string",
|
||||||
|
Description: "If set, read the contents of a directory relative to the current one.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
|
||||||
|
var relativeDir string
|
||||||
|
tmp, ok := args["relative_dir"]
|
||||||
|
if ok {
|
||||||
|
relativeDir, ok = tmp.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("Invalid relative_dir in function arguments: %v", tmp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result := readDir(relativeDir)
|
||||||
|
ret, err := result.ToJson()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("Could not serialize result: %v", err)
|
||||||
|
}
|
||||||
|
return ret, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func readDir(path string) model.CallResult {
|
||||||
|
if path == "" {
|
||||||
|
path = "."
|
||||||
|
}
|
||||||
|
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||||
|
if !ok {
|
||||||
|
return model.CallResult{Message: reason}
|
||||||
|
}
|
||||||
|
|
||||||
|
files, err := os.ReadDir(path)
|
||||||
|
if err != nil {
|
||||||
|
return model.CallResult{
|
||||||
|
Message: err.Error(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var dirContents []map[string]interface{}
|
||||||
|
for _, f := range files {
|
||||||
|
info, _ := f.Info()
|
||||||
|
|
||||||
|
name := f.Name()
|
||||||
|
if strings.HasPrefix(name, ".") {
|
||||||
|
// skip hidden files
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
entryType := "file"
|
||||||
|
size := info.Size()
|
||||||
|
|
||||||
|
if info.IsDir() {
|
||||||
|
name += "/"
|
||||||
|
entryType = "dir"
|
||||||
|
subdirfiles, _ := os.ReadDir(filepath.Join(".", path, info.Name()))
|
||||||
|
size = int64(len(subdirfiles))
|
||||||
|
}
|
||||||
|
|
||||||
|
dirContents = append(dirContents, map[string]interface{}{
|
||||||
|
"name": name,
|
||||||
|
"type": entryType,
|
||||||
|
"size": size,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return model.CallResult{Result: dirContents}
|
||||||
|
}
|
71
pkg/lmcli/tools/read_file.go
Normal file
71
pkg/lmcli/tools/read_file.go
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||||
|
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
const READ_FILE_DESCRIPTION = `Read the contents of a text file relative to the current working directory.
|
||||||
|
|
||||||
|
Each line of the returned content is prefixed with its line number and a tab (\t).
|
||||||
|
|
||||||
|
Example result:
|
||||||
|
{
|
||||||
|
"message": "success",
|
||||||
|
"result": "1\tthe contents\n2\tof the file\n"
|
||||||
|
}`
|
||||||
|
|
||||||
|
var ReadFileTool = model.Tool{
|
||||||
|
Name: "read_file",
|
||||||
|
Description: READ_FILE_DESCRIPTION,
|
||||||
|
Parameters: []model.ToolParameter{
|
||||||
|
{
|
||||||
|
Name: "path",
|
||||||
|
Type: "string",
|
||||||
|
Description: "Path to a file within the current working directory to read.",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
|
||||||
|
tmp, ok := args["path"]
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("Path parameter to read_file was not included.")
|
||||||
|
}
|
||||||
|
path, ok := tmp.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||||
|
}
|
||||||
|
result := readFile(path)
|
||||||
|
ret, err := result.ToJson()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("Could not serialize result: %v", err)
|
||||||
|
}
|
||||||
|
return ret, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func readFile(path string) model.CallResult {
|
||||||
|
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||||
|
if !ok {
|
||||||
|
return model.CallResult{Message: reason}
|
||||||
|
}
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
|
||||||
|
}
|
||||||
|
|
||||||
|
lines := strings.Split(string(data), "\n")
|
||||||
|
content := strings.Builder{}
|
||||||
|
for i, line := range lines {
|
||||||
|
content.WriteString(fmt.Sprintf("%d\t%s\n", i+1, line))
|
||||||
|
}
|
||||||
|
|
||||||
|
return model.CallResult{
|
||||||
|
Result: content.String(),
|
||||||
|
}
|
||||||
|
}
|
51
pkg/lmcli/tools/tools.go
Normal file
51
pkg/lmcli/tools/tools.go
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
var AvailableTools map[string]model.Tool = map[string]model.Tool{
|
||||||
|
"read_dir": ReadDirTool,
|
||||||
|
"read_file": ReadFileTool,
|
||||||
|
"write_file": WriteFileTool,
|
||||||
|
"file_insert_lines": FileInsertLinesTool,
|
||||||
|
"file_replace_lines": FileReplaceLinesTool,
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExecuteToolCalls(toolCalls []model.ToolCall, toolBag []model.Tool) ([]model.ToolResult, error) {
|
||||||
|
var toolResults []model.ToolResult
|
||||||
|
for _, toolCall := range toolCalls {
|
||||||
|
var tool *model.Tool
|
||||||
|
for _, available := range toolBag {
|
||||||
|
if available.Name == toolCall.Name {
|
||||||
|
tool = &available
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if tool == nil {
|
||||||
|
return nil, fmt.Errorf("Requested tool '%s' does not exist. Hallucination?", toolCall.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: ability to silence this
|
||||||
|
fmt.Fprintf(os.Stderr, "\nINFO: Executing tool '%s' with args %s\n", toolCall.Name, toolCall.Parameters)
|
||||||
|
|
||||||
|
// Execute the tool
|
||||||
|
result, err := tool.Impl(tool, toolCall.Parameters)
|
||||||
|
if err != nil {
|
||||||
|
// This can happen if the model missed or supplied invalid tool args
|
||||||
|
return nil, fmt.Errorf("Tool '%s' error: %v\n", toolCall.Name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
toolResult := model.ToolResult{
|
||||||
|
ToolCallID: toolCall.ID,
|
||||||
|
ToolName: toolCall.Name,
|
||||||
|
Result: result,
|
||||||
|
}
|
||||||
|
|
||||||
|
toolResults = append(toolResults, toolResult)
|
||||||
|
}
|
||||||
|
return toolResults, nil
|
||||||
|
}
|
67
pkg/lmcli/tools/util/util.go
Normal file
67
pkg/lmcli/tools/util/util.go
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// isPathContained attempts to verify whether `path` is the same as or
|
||||||
|
// contained within `directory`. It is overly cautious, returning false even if
|
||||||
|
// `path` IS contained within `directory`, but the two paths use different
|
||||||
|
// casing, and we happen to be on a case-insensitive filesystem.
|
||||||
|
// This is ultimately to attempt to stop an LLM from going outside of where I
|
||||||
|
// tell it to. Additional layers of security should be considered.. run in a
|
||||||
|
// VM/container.
|
||||||
|
func IsPathContained(directory string, path string) (bool, error) {
|
||||||
|
// Clean and resolve symlinks for both paths
|
||||||
|
path, err := filepath.Abs(path)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if path exists
|
||||||
|
_, err = os.Stat(path)
|
||||||
|
if err != nil {
|
||||||
|
if !os.IsNotExist(err) {
|
||||||
|
return false, fmt.Errorf("Could not stat path: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
path, err = filepath.EvalSymlinks(path)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
directory, err = filepath.Abs(directory)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
directory, err = filepath.EvalSymlinks(directory)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Case insensitive checks
|
||||||
|
if !strings.EqualFold(path, directory) &&
|
||||||
|
!strings.HasPrefix(strings.ToLower(path), strings.ToLower(directory)+string(os.PathSeparator)) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsPathWithinCWD(path string) (bool, string) {
|
||||||
|
cwd, err := os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
return false, "Failed to determine current working directory"
|
||||||
|
}
|
||||||
|
if ok, err := IsPathContained(cwd, path); !ok {
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Sprintf("Could not determine whether path '%s' is within the current working directory: %s", path, err.Error())
|
||||||
|
}
|
||||||
|
return false, fmt.Sprintf("Path '%s' is not within the current working directory", path)
|
||||||
|
}
|
||||||
|
return true, ""
|
||||||
|
}
|
71
pkg/lmcli/tools/write_file.go
Normal file
71
pkg/lmcli/tools/write_file.go
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||||
|
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
const WRITE_FILE_DESCRIPTION = `Write the provided contents to a file relative to the current working directory.
|
||||||
|
|
||||||
|
Example result:
|
||||||
|
{
|
||||||
|
"message": "success"
|
||||||
|
}`
|
||||||
|
|
||||||
|
var WriteFileTool = model.Tool{
|
||||||
|
Name: "write_file",
|
||||||
|
Description: WRITE_FILE_DESCRIPTION,
|
||||||
|
Parameters: []model.ToolParameter{
|
||||||
|
{
|
||||||
|
Name: "path",
|
||||||
|
Type: "string",
|
||||||
|
Description: "Path to a file within the current working directory to write to.",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "content",
|
||||||
|
Type: "string",
|
||||||
|
Description: "The content to write to the file. Overwrites any existing content!",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Impl: func(t *model.Tool, args map[string]interface{}) (string, error) {
|
||||||
|
tmp, ok := args["path"]
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("Path parameter to write_file was not included.")
|
||||||
|
}
|
||||||
|
path, ok := tmp.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||||
|
}
|
||||||
|
tmp, ok = args["content"]
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("Content parameter to write_file was not included.")
|
||||||
|
}
|
||||||
|
content, ok := tmp.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
||||||
|
}
|
||||||
|
result := writeFile(path, content)
|
||||||
|
ret, err := result.ToJson()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("Could not serialize result: %v", err)
|
||||||
|
}
|
||||||
|
return ret, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeFile(path string, content string) model.CallResult {
|
||||||
|
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||||
|
if !ok {
|
||||||
|
return model.CallResult{Message: reason}
|
||||||
|
}
|
||||||
|
err := os.WriteFile(path, []byte(content), 0644)
|
||||||
|
if err != nil {
|
||||||
|
return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
|
||||||
|
}
|
||||||
|
return model.CallResult{}
|
||||||
|
}
|
60
pkg/util/tty/highlight.go
Normal file
60
pkg/util/tty/highlight.go
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
package tty
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/alecthomas/chroma/v2"
|
||||||
|
"github.com/alecthomas/chroma/v2/formatters"
|
||||||
|
"github.com/alecthomas/chroma/v2/lexers"
|
||||||
|
"github.com/alecthomas/chroma/v2/styles"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ChromaHighlighter struct {
|
||||||
|
lexer chroma.Lexer
|
||||||
|
formatter chroma.Formatter
|
||||||
|
style *chroma.Style
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewChromaHighlighter(lang, format, style string) *ChromaHighlighter {
|
||||||
|
l := lexers.Get(lang)
|
||||||
|
if l == nil {
|
||||||
|
l = lexers.Fallback
|
||||||
|
}
|
||||||
|
l = chroma.Coalesce(l)
|
||||||
|
|
||||||
|
f := formatters.Get(format)
|
||||||
|
if f == nil {
|
||||||
|
f = formatters.Fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
s := styles.Get(style)
|
||||||
|
if s == nil {
|
||||||
|
s = styles.Fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ChromaHighlighter{
|
||||||
|
lexer: l,
|
||||||
|
formatter: f,
|
||||||
|
style: s,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ChromaHighlighter) Highlight(w io.Writer, text string) error {
|
||||||
|
it, err := s.lexer.Tokenise(nil, text)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return s.formatter.Format(w, s.style, it)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ChromaHighlighter) HighlightS(text string) (string, error) {
|
||||||
|
it, err := s.lexer.Tokenise(nil, text)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
sb := strings.Builder{}
|
||||||
|
sb.Grow(len(text) * 2)
|
||||||
|
s.formatter.Format(&sb, s.style, it)
|
||||||
|
return sb.String(), nil
|
||||||
|
}
|
@ -1,4 +1,4 @@
|
|||||||
package cli
|
package util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -56,7 +56,7 @@ func InputFromEditor(placeholder string, pattern string, content string) (string
|
|||||||
|
|
||||||
// humanTimeElapsedSince returns a human-friendly "in the past" representation
|
// humanTimeElapsedSince returns a human-friendly "in the past" representation
|
||||||
// of the given duration.
|
// of the given duration.
|
||||||
func humanTimeElapsedSince(d time.Duration) string {
|
func HumanTimeElapsedSince(d time.Duration) string {
|
||||||
seconds := d.Seconds()
|
seconds := d.Seconds()
|
||||||
minutes := seconds / 60
|
minutes := seconds / 60
|
||||||
hours := minutes / 60
|
hours := minutes / 60
|
||||||
@ -151,6 +151,14 @@ func SetStructDefaults(data interface{}) bool {
|
|||||||
intValue, _ := strconv.ParseInt(defaultTag, 10, 64)
|
intValue, _ := strconv.ParseInt(defaultTag, 10, 64)
|
||||||
field.Set(reflect.New(e))
|
field.Set(reflect.New(e))
|
||||||
field.Elem().SetInt(intValue)
|
field.Elem().SetInt(intValue)
|
||||||
|
case reflect.Float32:
|
||||||
|
floatValue, _ := strconv.ParseFloat(defaultTag, 32)
|
||||||
|
field.Set(reflect.New(e))
|
||||||
|
field.Elem().SetFloat(floatValue)
|
||||||
|
case reflect.Float64:
|
||||||
|
floatValue, _ := strconv.ParseFloat(defaultTag, 64)
|
||||||
|
field.Set(reflect.New(e))
|
||||||
|
field.Elem().SetFloat(floatValue)
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
boolValue := defaultTag == "true"
|
boolValue := defaultTag == "true"
|
||||||
field.Set(reflect.ValueOf(&boolValue))
|
field.Set(reflect.ValueOf(&boolValue))
|
||||||
@ -160,10 +168,8 @@ func SetStructDefaults(data interface{}) bool {
|
|||||||
return changed
|
return changed
|
||||||
}
|
}
|
||||||
|
|
||||||
// FileContents returns the string contents of the given file.
|
// ReadFileContents returns the string contents of the given file.
|
||||||
// TODO: we should support retrieving the content (or an approximation of)
|
func ReadFileContents(file string) (string, error) {
|
||||||
// non-text documents, e.g. PDFs.
|
|
||||||
func FileContents(file string) (string, error) {
|
|
||||||
path := filepath.Clean(file)
|
path := filepath.Clean(file)
|
||||||
content, err := os.ReadFile(path)
|
content, err := os.ReadFile(path)
|
||||||
if err != nil {
|
if err != nil {
|
Loading…
Reference in New Issue
Block a user