package api

import (
	"context"
	"errors"
	"io"

	"github.com/sashabaranov/go-openai"
)

// Client wraps OpenAI API client
type Client struct {
	client *openai.Client
	model  string
}

// NewClient creates a new OpenAI API client
func NewClient(apiKey, model string) *Client {
	openaiConfig := openai.DefaultConfig(apiKey)
	openaiConfig.BaseURL = "https://api.msh.team/v1"

	return &Client{
		client: openai.NewClientWithConfig(openaiConfig),
		model:  model,
	}
}

// SetModel changes the model used for requests
func (c *Client) SetModel(model string) {
	c.model = model
}

// SendChatCompletion sends a chat completion request and returns a stream
func (c *Client) SendChatCompletion(messages []openai.ChatCompletionMessage) (*openai.ChatCompletionStream, error) {
	req := openai.ChatCompletionRequest{
		Model:    c.model,
		Messages: messages,
		Stream:   true,
	}

	return c.client.CreateChatCompletionStream(context.Background(), req)
}

// GetNextResponse gets the next response from a stream
func (c *Client) GetNextResponse(stream *openai.ChatCompletionStream) (openai.ChatCompletionStreamResponse, error) {
	resp, err := stream.Recv()
	if err != nil {
		if errors.Is(err, io.EOF) {
			return openai.ChatCompletionStreamResponse{}, io.EOF
		}
		return openai.ChatCompletionStreamResponse{}, err
	}
	return resp, nil
}

// CloseStream closes a stream
func (c *Client) CloseStream(stream *openai.ChatCompletionStream) {
	if stream != nil {
		_ = stream.Close()
	}
}
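
// Example usage (illustrative sketch, not part of the original file): stream a
// chat completion and print the content deltas as they arrive. The response
// shape (resp.Choices[0].Delta.Content) and the ChatMessageRoleUser constant
// come from the go-openai library; the model name is a placeholder, and fmt
// would need to be imported by the caller.
//
//	c := NewClient(apiKey, "some-model")
//	stream, err := c.SendChatCompletion([]openai.ChatCompletionMessage{
//		{Role: openai.ChatMessageRoleUser, Content: "Hello"},
//	})
//	if err != nil {
//		return err
//	}
//	defer c.CloseStream(stream)
//	for {
//		resp, err := c.GetNextResponse(stream)
//		if errors.Is(err, io.EOF) {
//			break
//		}
//		if err != nil {
//			return err
//		}
//		fmt.Print(resp.Choices[0].Delta.Content)
//	}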