Skip to content

Commit

Permalink
chore: use any for n_epochs
Browse files Browse the repository at this point in the history
  • Loading branch information
henomis committed Oct 5, 2023
1 parent 2be7241 commit 581e933
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 39 deletions.
35 changes: 1 addition & 34 deletions fine_tuning_job.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package openai

import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
Expand All @@ -24,40 +23,8 @@ type FineTuningJob struct {
TrainedTokens int `json:"trained_tokens"`
}

type HyperparameterNEpochs struct {
IntValue *int `json:"-"`
StringValue *string `json:"-"`
}

func (h *HyperparameterNEpochs) UnmarshalJSON(data []byte) error {
var intValue int
var stringValue string

if err := json.Unmarshal(data, &intValue); err == nil {
h.IntValue = &intValue
return nil
}

if err := json.Unmarshal(data, &stringValue); err != nil {
return err
}

h.StringValue = &stringValue
return nil
}

func (h *HyperparameterNEpochs) MarshalJSON() ([]byte, error) {
if h.IntValue != nil {
return json.Marshal(*h.IntValue)
} else if h.StringValue != nil {
return json.Marshal(*h.StringValue)
}

return nil, fmt.Errorf("invalid hyperparameter n_epochs")
}

type Hyperparameters struct {
Epochs *HyperparameterNEpochs `json:"n_epochs,omitempty"`
Epochs any `json:"n_epochs,omitempty"`
}

type FineTuningJobRequest struct {
Expand Down
6 changes: 1 addition & 5 deletions fine_tuning_job_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ func TestFineTuningJob(t *testing.T) {
server.RegisterHandler(
"/v1/fine_tuning/jobs",
func(w http.ResponseWriter, r *http.Request) {
nEpochs := "auto"
resBytes, _ := json.Marshal(FineTuningJob{
Object: "fine_tuning.job",
ID: testFineTuninigJobID,
Expand All @@ -35,10 +34,7 @@ func TestFineTuningJob(t *testing.T) {
ValidationFile: "",
TrainingFile: "file-abc123",
Hyperparameters: Hyperparameters{
Epochs: &HyperparameterNEpochs{
IntValue: nil,
StringValue: &nEpochs,
},
Epochs: "auto",
},
TrainedTokens: 5768,
})
Expand Down

0 comments on commit 581e933

Please sign in to comment.