From 794a5512f60c74bb693eed684d04a4387f1e389d Mon Sep 17 00:00:00 2001 From: sm2642 <123341426+sm2642@users.noreply.github.com> Date: Sat, 28 Jan 2023 10:25:38 -0800 Subject: [PATCH] -Added moderation endpoint test (#56) -Rearrange some code Co-authored-by: Shalin --- api_test.go | 153 ++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 119 insertions(+), 34 deletions(-) diff --git a/api_test.go b/api_test.go index 1e1c5d0..8848d48 100644 --- a/api_test.go +++ b/api_test.go @@ -121,6 +121,30 @@ func TestEdits(t *testing.T) { } } +// TestModeration Tests the moderations endpoint of the API using the mocked server. +func TestModerations(t *testing.T) { + // create the test server + var err error + ts := OpenAITestServer() + ts.Start() + defer ts.Close() + + client := NewClient(testAPIToken) + ctx := context.Background() + client.BaseURL = ts.URL + "/v1" + + // create an edit request + model := "text-moderation-stable" + moderationReq := ModerationRequest{ + Model: &model, + Input: "I want to kill them.", + } + _, err = client.Moderations(ctx, moderationReq) + if err != nil { + t.Fatalf("Moderation error: %v", err) + } +} + func TestEmbedding(t *testing.T) { embeddedModels := []EmbeddingModel{ AdaSimilarity, @@ -160,6 +184,25 @@ func TestEmbedding(t *testing.T) { } } +func TestImages(t *testing.T) { + // create the test server + var err error + ts := OpenAITestServer() + ts.Start() + defer ts.Close() + + client := NewClient(testAPIToken) + ctx := context.Background() + client.BaseURL = ts.URL + "/v1" + + req := ImageRequest{} + req.Prompt = "Lorem ipsum" + _, err = client.CreateImage(ctx, req) + if err != nil { + t.Fatalf("CreateImage error: %v", err) + } +} + // getEditBody Returns the body of the request to create an edit. func getEditBody(r *http.Request) (EditsRequest, error) { edit := EditsRequest{} @@ -261,6 +304,21 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, string(resBytes)) } +// getCompletionBody Returns the body of the request to create a completion. +func getCompletionBody(r *http.Request) (CompletionRequest, error) { + completion := CompletionRequest{} + // read the request body + reqBody, err := ioutil.ReadAll(r.Body) + if err != nil { + return CompletionRequest{}, err + } + err = json.Unmarshal(reqBody, &completion) + if err != nil { + return CompletionRequest{}, err + } + return completion, nil +} + // handleImageEndpoint Handles the images endpoint by the test server. func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { var err error @@ -296,21 +354,6 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, string(resBytes)) } -// getCompletionBody Returns the body of the request to create a completion. -func getCompletionBody(r *http.Request) (CompletionRequest, error) { - completion := CompletionRequest{} - // read the request body - reqBody, err := ioutil.ReadAll(r.Body) - if err != nil { - return CompletionRequest{}, err - } - err = json.Unmarshal(reqBody, &completion) - if err != nil { - return CompletionRequest{}, err - } - return completion, nil -} - // getImageBody Returns the body of the request to create a image. func getImageBody(r *http.Request) (ImageRequest, error) { image := ImageRequest{} @@ -326,6 +369,65 @@ func getImageBody(r *http.Request) (ImageRequest, error) { return image, nil } +// handleModerationEndpoint Handles the moderation endpoint by the test server. +func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) { + var err error + var resBytes []byte + + // completions only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + var moderationReq ModerationRequest + if moderationReq, err = getModerationBody(r); err != nil { + http.Error(w, "could not read request", http.StatusInternalServerError) + return + } + + resCat := ResultCategories{} + resCatScore := ResultCategoryScores{} + switch { + case strings.Contains(moderationReq.Input, "kill"): + resCat = ResultCategories{Violence: true} + resCatScore = ResultCategoryScores{Violence: 1} + case strings.Contains(moderationReq.Input, "hate"): + resCat = ResultCategories{Hate: true} + resCatScore = ResultCategoryScores{Hate: 1} + case strings.Contains(moderationReq.Input, "suicide"): + resCat = ResultCategories{SelfHarm: true} + resCatScore = ResultCategoryScores{SelfHarm: 1} + case strings.Contains(moderationReq.Input, "porn"): + resCat = ResultCategories{Sexual: true} + resCatScore = ResultCategoryScores{Sexual: 1} + } + + result := Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true} + + res := ModerationResponse{ + ID: strconv.Itoa(int(time.Now().Unix())), + Model: *moderationReq.Model, + } + res.Results = append(res.Results, result) + + resBytes, _ = json.Marshal(res) + fmt.Fprintln(w, string(resBytes)) +} + +// getModerationBody Returns the body of the request to do a moderation. +func getModerationBody(r *http.Request) (ModerationRequest, error) { + moderation := ModerationRequest{} + // read the request body + reqBody, err := ioutil.ReadAll(r.Body) + if err != nil { + return ModerationRequest{}, err + } + err = json.Unmarshal(reqBody, &moderation) + if err != nil { + return ModerationRequest{}, err + } + return moderation, nil +} + // numTokens Returns the number of GPT-3 encoded tokens in the given text. // This function approximates based on the rule of thumb stated by OpenAI: // https://beta.openai.com/tokenizer @@ -335,25 +437,6 @@ func numTokens(s string) int { return int(float32(len(s)) / 4) } -func TestImages(t *testing.T) { - // create the test server - var err error - ts := OpenAITestServer() - ts.Start() - defer ts.Close() - - client := NewClient(testAPIToken) - ctx := context.Background() - client.BaseURL = ts.URL + "/v1" - - req := ImageRequest{} - req.Prompt = "Lorem ipsum" - _, err = client.CreateImage(ctx, req) - if err != nil { - t.Fatalf("CreateImage error: %v", err) - } -} - // OpenAITestServer Creates a mocked OpenAI server which can pretend to handle requests during testing. func OpenAITestServer() *httptest.Server { return httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -373,6 +456,8 @@ func OpenAITestServer() *httptest.Server { case "/v1/completions": handleCompletionEndpoint(w, r) return + case "/v1/moderations": + handleModerationEndpoint(w, r) case "/v1/images/generations": handleImageEndpoint(w, r) // TODO: implement the other endpoints