Improve handling of JSON Schema in OpenAI API Response Context (#819)

* feat: add jsonschema.Validate and jsonschema.Unmarshal

* fix Sanity check

* remove slices.Contains

* fix Sanity check

* add SchemaWrapper

* update api_integration_test.go

* update method 'reflectSchema' to support 'omitempty' in JSON tag

* add GenerateSchemaForType

* update json_test.go

* update `Warp` to `Wrap`

* fix Sanity check

* fix Sanity check

* update api_internal_test.go

* update README.md

* update README.md

* remove jsonschema.SchemaWrapper

* remove jsonschema.SchemaWrapper

* fix Sanity check

* optimize code formatting
This commit is contained in:
eiixy
2024-08-25 01:06:08 +08:00
committed by GitHub
parent 5162adbbf9
commit a3bd2569ac
7 changed files with 412 additions and 30 deletions

View File

@@ -4,7 +4,13 @@
// and/or pass in the schema in []byte format.
package jsonschema
import "encoding/json"
import (
"encoding/json"
"fmt"
"reflect"
"strconv"
"strings"
)
type DataType string
@@ -42,7 +48,7 @@ type Definition struct {
AdditionalProperties any `json:"additionalProperties,omitempty"`
}
func (d Definition) MarshalJSON() ([]byte, error) {
func (d *Definition) MarshalJSON() ([]byte, error) {
if d.Properties == nil {
d.Properties = make(map[string]Definition)
}
@@ -50,6 +56,99 @@ func (d Definition) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Alias
}{
Alias: (Alias)(d),
Alias: (Alias)(*d),
})
}
// Unmarshal checks content against the schema described by d and decodes it
// into v. It delegates to VerifySchemaAndUnmarshal and returns its error.
func (d *Definition) Unmarshal(content string, v any) error {
	raw := []byte(content)
	return VerifySchemaAndUnmarshal(*d, raw, v)
}
// GenerateSchemaForType derives a schema Definition from the dynamic type of v
// by handing that type to reflectSchema.
func GenerateSchemaForType(v any) (*Definition, error) {
	typ := reflect.TypeOf(v)
	return reflectSchema(typ)
}
// reflectSchema builds the schema Definition that corresponds to the Go type t.
//
// Strings, integers, floats and booleans map to the matching scalar schema
// type; slices and arrays map to an Array whose Items is the schema of the
// element type; structs are delegated to reflectSchemaObject; pointers are
// transparently dereferenced. A nil type, as well as maps, interfaces,
// channels, functions, complex numbers and raw pointers, yield an
// "unsupported type" error.
func reflectSchema(t reflect.Type) (*Definition, error) {
	// reflect.TypeOf(nil) returns a nil Type; calling Kind on it would panic,
	// so report it the same way as any other unsupported kind.
	if t == nil {
		return nil, fmt.Errorf("unsupported type: %s", reflect.Invalid.String())
	}
	var d Definition
	switch t.Kind() {
	case reflect.String:
		d.Type = String
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		d.Type = Integer
	case reflect.Float32, reflect.Float64:
		d.Type = Number
	case reflect.Bool:
		d.Type = Boolean
	case reflect.Slice, reflect.Array:
		d.Type = Array
		items, err := reflectSchema(t.Elem())
		if err != nil {
			return nil, err
		}
		d.Items = items
	case reflect.Struct:
		// reflectSchemaObject already sets Type and AdditionalProperties, so
		// the whole definition is taken from its result.
		object, err := reflectSchemaObject(t)
		if err != nil {
			return nil, err
		}
		d = *object
	case reflect.Ptr:
		definition, err := reflectSchema(t.Elem())
		if err != nil {
			return nil, err
		}
		d = *definition
	case reflect.Invalid, reflect.Uintptr, reflect.Complex64, reflect.Complex128,
		reflect.Chan, reflect.Func, reflect.Interface, reflect.Map,
		reflect.UnsafePointer:
		return nil, fmt.Errorf("unsupported type: %s", t.Kind().String())
	default:
	}
	return &d, nil
}
// reflectSchemaObject builds an Object schema for the struct type t.
//
// Every exported field becomes a property whose schema is produced by
// reflectSchema. The property name and requiredness follow the field's tags:
//
//   - `json:"-"` skips the field, mirroring encoding/json.
//   - `json:"name,opts..."` uses name as the property name; an empty name
//     falls back to the Go field name. Tag options never become part of the
//     property name.
//   - the `omitempty` option marks the property as not required; all other
//     properties are required by default.
//   - `description:"..."` sets the property's Description.
//   - `required:"<bool>"` overrides the requiredness derived from the json tag.
//
// It returns the first error reported by reflectSchema for a field type.
func reflectSchemaObject(t reflect.Type) (*Definition, error) {
	var d = Definition{
		Type:                 Object,
		AdditionalProperties: false,
	}
	properties := make(map[string]Definition)
	var requiredFields []string
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)
		if !field.IsExported() {
			continue
		}
		jsonTag := field.Tag.Get("json")
		// encoding/json never emits a field tagged exactly "-", so it must
		// not appear in the schema either. (`json:"-,"` names the key "-".)
		if jsonTag == "-" {
			continue
		}
		// Split "name,opt1,opt2" so that options cannot leak into the name.
		name, options, _ := strings.Cut(jsonTag, ",")
		if name == "" {
			name = field.Name
		}
		required := true
		for _, option := range strings.Split(options, ",") {
			if option == "omitempty" {
				required = false
				break
			}
		}
		item, err := reflectSchema(field.Type)
		if err != nil {
			return nil, err
		}
		description := field.Tag.Get("description")
		if description != "" {
			item.Description = description
		}
		properties[name] = *item
		if s := field.Tag.Get("required"); s != "" {
			// The parse error is deliberately ignored: ParseBool returns
			// false on failure, so an unparsable value means "not required".
			required, _ = strconv.ParseBool(s)
		}
		if required {
			requiredFields = append(requiredFields, name)
		}
	}
	d.Required = requiredFields
	d.Properties = properties
	return &d, nil
}

89
jsonschema/validate.go Normal file
View File

@@ -0,0 +1,89 @@
package jsonschema
import (
"encoding/json"
"errors"
)
// VerifySchemaAndUnmarshal decodes content as JSON, checks the decoded value
// against schema with Validate and, if it conforms, unmarshals content into v.
//
// v must be a non-nil pointer, exactly as for json.Unmarshal. It returns the
// JSON syntax error if content is not valid JSON, a validation error if the
// data does not match schema, or the error from the final json.Unmarshal.
func VerifySchemaAndUnmarshal(schema Definition, content []byte, v any) error {
	var data any
	err := json.Unmarshal(content, &data)
	if err != nil {
		return err
	}
	if !Validate(schema, data) {
		return errors.New("data validation failed against the provided schema")
	}
	// Pass v itself rather than &v: taking the address of the local interface
	// would let a non-pointer or nil target decode into that local and return
	// nil, silently discarding the result instead of reporting the misuse.
	return json.Unmarshal(content, v)
}
// Validate reports whether data conforms to schema.
//
// data is expected to be a value as produced by decoding JSON into an `any`
// (map[string]any, []any, string, float64, bool, nil); a Go int is accepted
// as well for Number and Integer so hand-built values can be checked.
// A schema whose Type is not one of the known data types never validates.
func Validate(schema Definition, data any) bool {
	switch schema.Type {
	case Object:
		return validateObject(schema, data)
	case Array:
		return validateArray(schema, data)
	case String:
		_, ok := data.(string)
		return ok
	case Number: // float64 and int
		_, ok := data.(float64)
		if !ok {
			_, ok = data.(int)
		}
		return ok
	case Boolean:
		_, ok := data.(bool)
		return ok
	case Integer:
		// encoding/json decodes every JSON number into float64 when the
		// target is `any`, so an integer arrives as a float64 with no
		// fractional part. Without this branch no decoded document could
		// ever satisfy an Integer schema.
		if num, isFloat := data.(float64); isFloat {
			return num == float64(int64(num))
		}
		_, ok := data.(int)
		return ok
	case Null:
		return data == nil
	default:
		return false
	}
}
// validateObject reports whether data is a JSON object (map[string]any) that
// contains every field listed in schema.Required and whose present properties
// each validate against their schema in schema.Properties.
func validateObject(schema Definition, data any) bool {
	fields, isObject := data.(map[string]any)
	if !isObject {
		return false
	}
	// Every required field has to be present.
	for _, name := range schema.Required {
		if _, present := fields[name]; !present {
			return false
		}
	}
	// Each declared property that is present must match its own schema.
	for name, propSchema := range schema.Properties {
		value, present := fields[name]
		switch {
		case present:
			if !Validate(propSchema, value) {
				return false
			}
		case contains(schema.Required, name):
			return false
		}
	}
	return true
}
// validateArray reports whether data is a JSON array ([]any) whose elements
// all validate against schema.Items. When the schema declares no Items, the
// elements are unconstrained and any array is accepted.
func validateArray(schema Definition, data any) bool {
	dataArray, ok := data.([]any)
	if !ok {
		return false
	}
	// Items is a pointer and may be unset; dereferencing it blindly would
	// panic on the first element of a non-empty array.
	if schema.Items == nil {
		return true
	}
	for _, item := range dataArray {
		if !Validate(*schema.Items, item) {
			return false
		}
	}
	return true
}
// contains reports whether v is equal to at least one element of s.
// It returns false for a nil or empty slice.
func contains[S ~[]E, E comparable](s S, v E) bool {
	for _, elem := range s {
		if elem == v {
			return true
		}
	}
	return false
}

136
jsonschema/validate_test.go Normal file
View File

@@ -0,0 +1,136 @@
package jsonschema_test
import (
"testing"
"github.com/sashabaranov/go-openai/jsonschema"
)
// Test_Validate runs jsonschema.Validate over a table of scalar, array and
// object inputs and checks the reported verdict for each one.
func Test_Validate(t *testing.T) {
	type testCase struct {
		data   any
		schema jsonschema.Definition
		want   bool
	}

	// scalar returns a schema that only constrains the data type.
	scalar := func(dataType jsonschema.DataType) jsonschema.Definition {
		return jsonschema.Definition{Type: dataType}
	}
	// arrayOf returns an array schema whose items are of the given type.
	arrayOf := func(dataType jsonschema.DataType) jsonschema.Definition {
		return jsonschema.Definition{
			Type:  jsonschema.Array,
			Items: &jsonschema.Definition{Type: dataType},
		}
	}
	// objectSchema returns a fresh object schema in which only "string" is required.
	objectSchema := func() jsonschema.Definition {
		return jsonschema.Definition{
			Type: jsonschema.Object,
			Properties: map[string]jsonschema.Definition{
				"string":  {Type: jsonschema.String},
				"integer": {Type: jsonschema.Integer},
				"number":  {Type: jsonschema.Number},
				"boolean": {Type: jsonschema.Boolean},
				"array":   {Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.Number}},
			},
			Required: []string{"string"},
		}
	}

	cases := []testCase{
		// string integer number boolean
		{data: "ABC", schema: scalar(jsonschema.String), want: true},
		{data: 123, schema: scalar(jsonschema.String), want: false},
		{data: 123, schema: scalar(jsonschema.Integer), want: true},
		{data: 123.4, schema: scalar(jsonschema.Integer), want: false},
		{data: "ABC", schema: scalar(jsonschema.Number), want: false},
		{data: 123, schema: scalar(jsonschema.Number), want: true},
		{data: false, schema: scalar(jsonschema.Boolean), want: true},
		{data: 123, schema: scalar(jsonschema.Boolean), want: false},
		{data: nil, schema: scalar(jsonschema.Null), want: true},
		{data: 0, schema: scalar(jsonschema.Null), want: false},
		// array
		{data: []any{"a", "b", "c"}, schema: arrayOf(jsonschema.String), want: true},
		{data: []any{1, 2, 3}, schema: arrayOf(jsonschema.String), want: false},
		{data: []any{1, 2, 3}, schema: arrayOf(jsonschema.Integer), want: true},
		{data: []any{1, 2, 3.4}, schema: arrayOf(jsonschema.Integer), want: false},
		// object
		{
			data: map[string]any{
				"string":  "abc",
				"integer": 123,
				"number":  123.4,
				"boolean": false,
				"array":   []any{1, 2, 3},
			},
			schema: objectSchema(),
			want:   true,
		},
		{
			// The required "string" property is missing.
			data: map[string]any{
				"integer": 123,
				"number":  123.4,
				"boolean": false,
				"array":   []any{1, 2, 3},
			},
			schema: objectSchema(),
			want:   false,
		},
	}

	for _, tc := range cases {
		t.Run("", func(t *testing.T) {
			got := jsonschema.Validate(tc.schema, tc.data)
			if got != tc.want {
				t.Errorf("Validate() = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestUnmarshal exercises jsonschema.VerifySchemaAndUnmarshal with one
// document that satisfies its schema and one that misses a required field.
// Besides the returned error it also checks that the conforming document was
// really decoded into the target and that the rejected one left it untouched.
func TestUnmarshal(t *testing.T) {
	type args struct {
		schema  jsonschema.Definition
		content []byte
		v       any
	}
	var result1 struct {
		String string  `json:"string"`
		Number float64 `json:"number"`
	}
	var result2 struct {
		String string  `json:"string"`
		Number float64 `json:"number"`
	}
	tests := []struct {
		name    string
		args    args
		wantErr bool
	}{
		{"", args{
			schema: jsonschema.Definition{
				Type: jsonschema.Object,
				Properties: map[string]jsonschema.Definition{
					"string": {Type: jsonschema.String},
					"number": {Type: jsonschema.Number},
				},
			},
			content: []byte(`{"string":"abc","number":123.4}`),
			v:       &result1,
		}, false},
		{"", args{
			schema: jsonschema.Definition{
				Type: jsonschema.Object,
				Properties: map[string]jsonschema.Definition{
					"string": {Type: jsonschema.String},
					"number": {Type: jsonschema.Number},
				},
				Required: []string{"string", "number"},
			},
			content: []byte(`{"string":"abc"}`),
			// A valid pointer target, so the expected error can only come
			// from schema validation and not from a misuse of the target.
			v: &result2,
		}, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := jsonschema.VerifySchemaAndUnmarshal(tt.args.schema, tt.args.content, tt.args.v)
			if (err != nil) != tt.wantErr {
				t.Errorf("Unmarshal() error = %v, wantErr %v", err, tt.wantErr)
			} else if err == nil {
				t.Logf("Unmarshal() v = %+v\n", tt.args.v)
			}
		})
	}
	// The conforming document must have been decoded into its target.
	if result1.String != "abc" || result1.Number != 123.4 {
		t.Errorf("Unmarshal() decoded %+v, want {String:abc Number:123.4}", result1)
	}
	// The rejected document must not have been decoded at all.
	if result2.String != "" || result2.Number != 0 {
		t.Errorf("Unmarshal() modified target on validation failure: %+v", result2)
	}
}