Skip to content

Commit

Permalink
serialization fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
CalebCourier committed Jun 24, 2024
1 parent b70eedd commit 107bc8b
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 10 deletions.
2 changes: 1 addition & 1 deletion resources/py/pyproject.toml.template
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "guardrails-api-client"
version = "0.3.6"
version = "0.3.7"
description = "Guardrails API Client."
authors = [
{name = "Guardrails AI", email = "contact@guardrailsai.com"}
Expand Down
50 changes: 44 additions & 6 deletions resources/py/scripts/prebuild.js
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,9 @@ function fixPassResult() {
const passResult = passResultFile
.replace(
'_obj = cls.model_validate({',
'_obj = cls.model_validate({\n\t\t\t"metadata": obj.get("metadata"),\n\t\t\t"value_override": obj.get("value_override"),'
// The python formatter suddenly decided it didn't like tab characters
// even though they're fine later on in the string...
'\n if obj.get("outcome") != "pass":\n raise ValueError("PassResult must have an outcome value of \\"pass\\"!")\n _obj = cls.model_validate({\n\t\t\t"metadata": obj.get("metadata"),\n\t\t\t"value_override": obj.get("valueOverride"),'
)

if (passResultFile === passResult) {
Expand All @@ -142,7 +144,9 @@ function fixFailResult() {
const failResult = failResultFile
.replace(
'_obj = cls.model_validate({',
'_obj = cls.model_validate({\n\t\t\t"error_message": obj.get("error_message"),\n\t\t\t"fix_value": obj.get("fix_value"),\n\t\t\t"error_spans": [ErrorSpan.from_dict(es) for es in obj.get("error_spans", [])],\n\t\t\t"metadata": obj.get("metadata"),'
// The python formatter suddenly decided it didn't like tab characters
// even though they're fine later on in the string...
'\n if obj.get("outcome") != "fail":\n raise ValueError("FailResult must have an outcome value of \\"fail\\"!")\n _obj = cls.model_validate({\n\t\t\t"error_message": obj.get("errorMessage"),\n\t\t\t"fix_value": obj.get("fixValue"),\n\t\t\t"error_spans": [ErrorSpan.from_dict(es) for es in obj.get("errorSpans", [])],\n\t\t\t"metadata": obj.get("metadata"),'
)

if (failResultFile === failResult) {
Expand Down Expand Up @@ -271,6 +275,38 @@ function fixValidatorReferenceTypes () {
fs.writeFileSync(validatorReferenceFilePath, validatorReference)
}

function fixInputs() {
const inputsFilePath = path.resolve('./guardrails_api_client/models/inputs.py');
const inputsFile = fs.readFileSync(inputsFilePath).toString();
const inputs = inputsFile
.replace(
'_obj = cls.model_validate({',
'_obj = cls.model_validate({\n\t\t\t"promptParams": obj.get("promptParams"),\n\t\t\t"metadata": obj.get("metadata"),'
)

if (inputsFile === inputs) {
console.warn("Fixes in fixInputs may no longer be necessary!")
}

fs.writeFileSync(inputsFilePath, inputs)
}

function fixCallInputs() {
const callInputsFilePath = path.resolve('./guardrails_api_client/models/call_inputs.py');
const callInputsFile = fs.readFileSync(callInputsFilePath).toString();
const callInputs = callInputsFile
.replace(
'_obj = cls.model_validate({',
'_obj = cls.model_validate({\n\t\t\t"promptParams": obj.get("promptParams"),\n\t\t\t"metadata": obj.get("metadata"),\n\t\t\t"kwargs": obj.get("kwargs"),'
)

if (callInputsFile === callInputs) {
console.warn("Fixes in fixInputs may no longer be necessary!")
}

fs.writeFileSync(callInputsFilePath, callInputs)
}

function exportAll (filePath) {
const initFilePath = path.resolve(filePath);
const initFile = fs.readFileSync(initFilePath).toString();
Expand Down Expand Up @@ -307,6 +343,8 @@ function hotFixes () {
fixModelSchemaDefaults();
fixCallException();
fixValidatorReferenceTypes();
fixInputs();
fixCallInputs();
fixInits();
}

Expand All @@ -329,8 +367,8 @@ function globalReplacements () {
)
// TODO: Find a regex for these
.replace(
'"validated_chunk": object.from_dict(obj["validated_chunk"]) if obj.get("validated_chunk") is not None else None',
'"validated_chunk": obj.get("validated_chunk")'
'"validatedChunk": object.from_dict(obj["validatedChunk"]) if obj.get("validatedChunk") is not None else None',
'"validated_chunk": obj.get("validatedChunk")'
)
.replace(
'"incorrectValue": object.from_dict(obj["incorrectValue"]) if obj.get("incorrectValue") is not None else None,',
Expand All @@ -349,8 +387,8 @@ function globalReplacements () {
'_items.append(_item.to_dict() if hasattr(_item, "to_dict") and callable(_item.to_dict) else _item)'
)
.replace(
"_dict['validated_chunk'] = self.validated_chunk.to_dict()",
"_dict['validated_chunk'] = self.validated_chunk"
"_dict['validatedChunk'] = self.validated_chunk.to_dict()",
"_dict['validatedChunk'] = self.validated_chunk"
)
.replace(
"_dict['incorrectValue'] = self.incorrect_value.to_dict()",
Expand Down
4 changes: 2 additions & 2 deletions resources/ts/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion resources/ts/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@guardrails-ai/api-client",
"version": "0.3.6",
"version": "0.3.7",
"description": "Client libaray for interacting with the guardrails-api",
"main": "dist/index.js",
"types": "dist/index.d.ts",
Expand Down

0 comments on commit 107bc8b

Please sign in to comment.