move marshaller and unmarshaler into internal pkg (#304) (#325)

This commit is contained in:
渡邉祐一 / Yuichi Watanabe
2023-05-28 10:51:07 +09:00
committed by GitHub
parent 980504b47e
commit 62eb4beed2
12 changed files with 57 additions and 46 deletions

View File

@@ -4,6 +4,8 @@ import (
"bufio" "bufio"
"context" "context"
"net/http" "net/http"
utils "github.com/sashabaranov/go-openai/internal"
) )
type ChatCompletionStreamChoiceDelta struct { type ChatCompletionStreamChoiceDelta struct {
@@ -65,7 +67,7 @@ func (c *Client) CreateChatCompletionStream(
reader: bufio.NewReader(resp.Body), reader: bufio.NewReader(resp.Body),
response: resp, response: resp,
errAccumulator: newErrorAccumulator(), errAccumulator: newErrorAccumulator(),
unmarshaler: &jsonUnmarshaler{}, unmarshaler: &utils.JSONUnmarshaler{},
}, },
} }
return return

View File

@@ -4,6 +4,8 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
utils "github.com/sashabaranov/go-openai/internal"
) )
type errorAccumulator interface { type errorAccumulator interface {
@@ -19,13 +21,13 @@ type errorBuffer interface {
type defaultErrorAccumulator struct { type defaultErrorAccumulator struct {
buffer errorBuffer buffer errorBuffer
unmarshaler unmarshaler unmarshaler utils.Unmarshaler
} }
func newErrorAccumulator() errorAccumulator { func newErrorAccumulator() errorAccumulator {
return &defaultErrorAccumulator{ return &defaultErrorAccumulator{
buffer: &bytes.Buffer{}, buffer: &bytes.Buffer{},
unmarshaler: &jsonUnmarshaler{}, unmarshaler: &utils.JSONUnmarshaler{},
} }
} }
@@ -42,7 +44,7 @@ func (e *defaultErrorAccumulator) unmarshalError() (errResp *ErrorResponse) {
return return
} }
err := e.unmarshaler.unmarshal(e.buffer.Bytes(), &errResp) err := e.unmarshaler.Unmarshal(e.buffer.Bytes(), &errResp)
if err != nil { if err != nil {
errResp = nil errResp = nil
} }

View File

@@ -7,6 +7,7 @@ import (
"net/http" "net/http"
"testing" "testing"
utils "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks" "github.com/sashabaranov/go-openai/internal/test/checks"
) )
@@ -33,7 +34,7 @@ func (b *failingErrorBuffer) Bytes() []byte {
return []byte{} return []byte{}
} }
func (*failingUnMarshaller) unmarshal(_ []byte, _ any) error { func (*failingUnMarshaller) Unmarshal(_ []byte, _ any) error {
return errTestUnmarshalerFailed return errTestUnmarshalerFailed
} }
@@ -62,7 +63,7 @@ func TestErrorAccumulatorReturnsUnmarshalerErrors(t *testing.T) {
func TestErrorByteWriteErrors(t *testing.T) { func TestErrorByteWriteErrors(t *testing.T) {
accumulator := &defaultErrorAccumulator{ accumulator := &defaultErrorAccumulator{
buffer: &failingErrorBuffer{}, buffer: &failingErrorBuffer{},
unmarshaler: &jsonUnmarshaler{}, unmarshaler: &utils.JSONUnmarshaler{},
} }
err := accumulator.write([]byte("{")) err := accumulator.write([]byte("{"))
if !errors.Is(err, errTestErrorAccumulatorWriteFailed) { if !errors.Is(err, errTestErrorAccumulatorWriteFailed) {
@@ -91,7 +92,7 @@ func TestErrorAccumulatorWriteErrors(t *testing.T) {
stream.errAccumulator = &defaultErrorAccumulator{ stream.errAccumulator = &defaultErrorAccumulator{
buffer: &failingErrorBuffer{}, buffer: &failingErrorBuffer{},
unmarshaler: &jsonUnmarshaler{}, unmarshaler: &utils.JSONUnmarshaler{},
} }
_, err = stream.Recv() _, err = stream.Recv()

View File

@@ -1,7 +1,7 @@
package openai //nolint:testpackage // testing private field package openai //nolint:testpackage // testing private field
import ( import (
. "github.com/sashabaranov/go-openai/internal" utils "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks" "github.com/sashabaranov/go-openai/internal/test/checks"
@@ -86,7 +86,7 @@ func TestFileUploadWithFailingFormBuilder(t *testing.T) {
config.BaseURL = "" config.BaseURL = ""
client := NewClientWithConfig(config) client := NewClientWithConfig(config)
mockBuilder := &mockFormBuilder{} mockBuilder := &mockFormBuilder{}
client.createFormBuilder = func(io.Writer) FormBuilder { client.createFormBuilder = func(io.Writer) utils.FormBuilder {
return mockBuilder return mockBuilder
} }

15
internal/marshaller.go Normal file
View File

@@ -0,0 +1,15 @@
package openai
import (
"encoding/json"
)
type Marshaller interface {
Marshal(value any) ([]byte, error)
}
type JSONMarshaller struct{}
func (jm *JSONMarshaller) Marshal(value any) ([]byte, error) {
return json.Marshal(value)
}

15
internal/unmarshaler.go Normal file
View File

@@ -0,0 +1,15 @@
package openai
import (
"encoding/json"
)
type Unmarshaler interface {
Unmarshal(data []byte, v any) error
}
type JSONUnmarshaler struct{}
func (jm *JSONUnmarshaler) Unmarshal(data []byte, v any) error {
return json.Unmarshal(data, v)
}

View File

@@ -1,15 +0,0 @@
package openai
import (
"encoding/json"
)
type marshaller interface {
marshal(value any) ([]byte, error)
}
type jsonMarshaller struct{}
func (jm *jsonMarshaller) marshal(value any) ([]byte, error) {
return json.Marshal(value)
}

View File

@@ -4,6 +4,8 @@ import (
"bytes" "bytes"
"context" "context"
"net/http" "net/http"
utils "github.com/sashabaranov/go-openai/internal"
) )
type requestBuilder interface { type requestBuilder interface {
@@ -11,12 +13,12 @@ type requestBuilder interface {
} }
type httpRequestBuilder struct { type httpRequestBuilder struct {
marshaller marshaller marshaller utils.Marshaller
} }
func newRequestBuilder() *httpRequestBuilder { func newRequestBuilder() *httpRequestBuilder {
return &httpRequestBuilder{ return &httpRequestBuilder{
marshaller: &jsonMarshaller{}, marshaller: &utils.JSONMarshaller{},
} }
} }
@@ -26,7 +28,7 @@ func (b *httpRequestBuilder) build(ctx context.Context, method, url string, requ
} }
var reqBytes []byte var reqBytes []byte
reqBytes, err := b.marshaller.marshal(request) reqBytes, err := b.marshaller.Marshal(request)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -19,7 +19,7 @@ type (
failingMarshaller struct{} failingMarshaller struct{}
) )
func (*failingMarshaller) marshal(_ any) ([]byte, error) { func (*failingMarshaller) Marshal(_ any) ([]byte, error) {
return []byte{}, errTestMarshallerFailed return []byte{}, errTestMarshallerFailed
} }

View File

@@ -5,6 +5,8 @@ import (
"context" "context"
"errors" "errors"
"net/http" "net/http"
utils "github.com/sashabaranov/go-openai/internal"
) )
var ( var (
@@ -54,7 +56,7 @@ func (c *Client) CreateCompletionStream(
reader: bufio.NewReader(resp.Body), reader: bufio.NewReader(resp.Body),
response: resp, response: resp,
errAccumulator: newErrorAccumulator(), errAccumulator: newErrorAccumulator(),
unmarshaler: &jsonUnmarshaler{}, unmarshaler: &utils.JSONUnmarshaler{},
}, },
} }
return return

View File

@@ -6,6 +6,8 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
utils "github.com/sashabaranov/go-openai/internal"
) )
type streamable interface { type streamable interface {
@@ -19,7 +21,7 @@ type streamReader[T streamable] struct {
reader *bufio.Reader reader *bufio.Reader
response *http.Response response *http.Response
errAccumulator errorAccumulator errAccumulator errorAccumulator
unmarshaler unmarshaler unmarshaler utils.Unmarshaler
} }
func (stream *streamReader[T]) Recv() (response T, err error) { func (stream *streamReader[T]) Recv() (response T, err error) {
@@ -63,7 +65,7 @@ waitForData:
return return
} }
err = stream.unmarshaler.unmarshal(line, &response) err = stream.unmarshaler.Unmarshal(line, &response)
return return
} }

View File

@@ -1,15 +0,0 @@
package openai
import (
"encoding/json"
)
type unmarshaler interface {
unmarshal(data []byte, v any) error
}
type jsonUnmarshaler struct{}
func (jm *jsonUnmarshaler) unmarshal(data []byte, v any) error {
return json.Unmarshal(data, v)
}