diff --git a/core/src/main/java/feast/core/grpc/CoreServiceImpl.java b/core/src/main/java/feast/core/grpc/CoreServiceImpl.java index b8d0670d0d..661bbe2403 100644 --- a/core/src/main/java/feast/core/grpc/CoreServiceImpl.java +++ b/core/src/main/java/feast/core/grpc/CoreServiceImpl.java @@ -16,6 +16,7 @@ */ package feast.core.grpc; +import com.google.protobuf.InvalidProtocolBufferException; import feast.core.CoreServiceGrpc.CoreServiceImplBase; import feast.core.CoreServiceProto.ApplyFeatureSetRequest; import feast.core.CoreServiceProto.ApplyFeatureSetResponse; @@ -77,7 +78,7 @@ public void getFeatureSet( GetFeatureSetResponse response = specService.getFeatureSet(request); responseObserver.onNext(response); responseObserver.onCompleted(); - } catch (RetrievalException | StatusRuntimeException e) { + } catch (RetrievalException | StatusRuntimeException | InvalidProtocolBufferException e) { log.error("Exception has occurred in GetFeatureSet method: ", e); responseObserver.onError( Status.INTERNAL.withDescription(e.getMessage()).withCause(e).asRuntimeException()); @@ -91,7 +92,7 @@ public void listFeatureSets( ListFeatureSetsResponse response = specService.listFeatureSets(request.getFilter()); responseObserver.onNext(response); responseObserver.onCompleted(); - } catch (RetrievalException | IllegalArgumentException e) { + } catch (RetrievalException | IllegalArgumentException | InvalidProtocolBufferException e) { log.error("Exception has occurred in ListFeatureSet method: ", e); responseObserver.onError( Status.INTERNAL.withDescription(e.getMessage()).withCause(e).asRuntimeException()); diff --git a/core/src/main/java/feast/core/job/JobUpdateTask.java b/core/src/main/java/feast/core/job/JobUpdateTask.java index 57b2dfee4f..87578cce25 100644 --- a/core/src/main/java/feast/core/job/JobUpdateTask.java +++ b/core/src/main/java/feast/core/job/JobUpdateTask.java @@ -173,6 +173,7 @@ private Job startJob( return job; } catch (Exception e) { + log.error(e.getMessage()); 
AuditLogger.log( Resource.JOB, jobId, diff --git a/core/src/main/java/feast/core/job/dataflow/DataflowJobManager.java b/core/src/main/java/feast/core/job/dataflow/DataflowJobManager.java index 2de46ae1f2..7115ee3f66 100644 --- a/core/src/main/java/feast/core/job/dataflow/DataflowJobManager.java +++ b/core/src/main/java/feast/core/job/dataflow/DataflowJobManager.java @@ -78,17 +78,24 @@ public Runner getRunnerType() { @Override public Job startJob(Job job) { - List featureSetProtos = - job.getFeatureSets().stream().map(FeatureSet::toProto).collect(Collectors.toList()); try { + List featureSetProtos = new ArrayList<>(); + for (FeatureSet featureSet : job.getFeatureSets()) { + featureSetProtos.add(featureSet.toProto()); + } return submitDataflowJob( job.getId(), featureSetProtos, job.getSource().toProto(), job.getStore().toProto(), false); + } catch (InvalidProtocolBufferException e) { - throw new RuntimeException(String.format("Unable to start job %s", job.getId()), e); + log.error(e.getMessage()); + throw new IllegalArgumentException( + String.format("DataflowJobManager failed to START job with id '%s' because the job" + + " has an invalid spec. Please check the FeatureSet, Source and Store specs. 
Actual error message: %s", + job.getId(), e.getMessage())); } } @@ -101,14 +108,18 @@ public Job startJob(Job job) { @Override public Job updateJob(Job job) { try { - List featureSetProtos = - job.getFeatureSets().stream().map(FeatureSet::toProto).collect(Collectors.toList()); - - return submitDataflowJob( - job.getId(), featureSetProtos, job.getSource().toProto(), job.getStore().toProto(), true); - + List featureSetProtos = new ArrayList<>(); + for (FeatureSet featureSet : job.getFeatureSets()) { + featureSetProtos.add(featureSet.toProto()); + } + return submitDataflowJob(job.getId(), featureSetProtos, job.getSource().toProto(), + job.getStore().toProto(), true); } catch (InvalidProtocolBufferException e) { - throw new RuntimeException(String.format("Unable to update job %s", job.getId()), e); + log.error(e.getMessage()); + throw new IllegalArgumentException( + String.format("DataflowJobManager failed to UPDATE job with id '%s' because the job" + + " has an invalid spec. Please check the FeatureSet, Source and Store specs. 
Actual error message: %s", + job.getId(), e.getMessage())); } } diff --git a/core/src/main/java/feast/core/job/direct/DirectRunnerJobManager.java b/core/src/main/java/feast/core/job/direct/DirectRunnerJobManager.java index fdf3aad9bc..b01d37d892 100644 --- a/core/src/main/java/feast/core/job/direct/DirectRunnerJobManager.java +++ b/core/src/main/java/feast/core/job/direct/DirectRunnerJobManager.java @@ -21,7 +21,6 @@ import com.google.protobuf.util.JsonFormat; import com.google.protobuf.util.JsonFormat.Printer; import feast.core.FeatureSetProto; -import feast.core.FeatureSetProto.FeatureSetSpec; import feast.core.StoreProto; import feast.core.config.FeastProperties.MetricsProperties; import feast.core.exception.JobExecutionException; @@ -38,7 +37,6 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import org.apache.beam.runners.direct.DirectRunner; import org.apache.beam.sdk.PipelineResult; @@ -75,8 +73,10 @@ public Runner getRunnerType() { @Override public Job startJob(Job job) { try { - List featureSetProtos = - job.getFeatureSets().stream().map(FeatureSet::toProto).collect(Collectors.toList()); + List featureSetProtos = new ArrayList<>(); + for (FeatureSet featureSet : job.getFeatureSets()) { + featureSetProtos.add(featureSet.toProto()); + } ImportOptions pipelineOptions = getPipelineOptions(featureSetProtos, job.getStore().toProto()); PipelineResult pipelineResult = runPipeline(pipelineOptions); @@ -131,10 +131,6 @@ public Job updateJob(Job job) { String jobId = job.getExtId(); abortJob(jobId); try { - List featureSetSpecs = new ArrayList<>(); - for (FeatureSet featureSet : job.getFeatureSets()) { - featureSetSpecs.add(featureSet.toProto().getSpec()); - } return startJob(job); } catch (JobExecutionException e) { throw new JobExecutionException(String.format("Error running ingestion job: %s", e), e); diff --git a/core/src/main/java/feast/core/model/FeatureSet.java 
b/core/src/main/java/feast/core/model/FeatureSet.java index e468705020..cd6036fe5e 100644 --- a/core/src/main/java/feast/core/model/FeatureSet.java +++ b/core/src/main/java/feast/core/model/FeatureSet.java @@ -17,6 +17,7 @@ package feast.core.model; import com.google.protobuf.Duration; +import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Timestamp; import feast.core.FeatureSetProto; import feast.core.FeatureSetProto.EntitySpec; @@ -24,7 +25,7 @@ import feast.core.FeatureSetProto.FeatureSetSpec; import feast.core.FeatureSetProto.FeatureSetStatus; import feast.core.FeatureSetProto.FeatureSpec; -import feast.types.ValueProto.ValueType; +import feast.types.ValueProto.ValueType.Enum; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; @@ -47,6 +48,20 @@ import org.apache.commons.lang3.builder.HashCodeBuilder; import org.hibernate.annotations.Fetch; import org.hibernate.annotations.FetchMode; +import org.tensorflow.metadata.v0.BoolDomain; +import org.tensorflow.metadata.v0.FeaturePresence; +import org.tensorflow.metadata.v0.FeaturePresenceWithinGroup; +import org.tensorflow.metadata.v0.FixedShape; +import org.tensorflow.metadata.v0.FloatDomain; +import org.tensorflow.metadata.v0.ImageDomain; +import org.tensorflow.metadata.v0.IntDomain; +import org.tensorflow.metadata.v0.NaturalLanguageDomain; +import org.tensorflow.metadata.v0.StringDomain; +import org.tensorflow.metadata.v0.StructDomain; +import org.tensorflow.metadata.v0.TimeDomain; +import org.tensorflow.metadata.v0.TimeOfDayDomain; +import org.tensorflow.metadata.v0.URLDomain; +import org.tensorflow.metadata.v0.ValueCount; @Getter @Setter @@ -157,14 +172,14 @@ public static FeatureSet fromProto(FeatureSetProto.FeatureSet featureSetProto) { FeatureSetSpec featureSetSpec = featureSetProto.getSpec(); Source source = Source.fromProto(featureSetSpec.getSource()); - List features = new ArrayList<>(); - for (FeatureSpec feature : 
featureSetSpec.getFeaturesList()) { - features.add(new Field(feature.getName(), feature.getValueType())); + List featureSpecs = new ArrayList<>(); + for (FeatureSpec featureSpec : featureSetSpec.getFeaturesList()) { + featureSpecs.add(new Field(featureSpec)); } - List entities = new ArrayList<>(); - for (EntitySpec entity : featureSetSpec.getEntitiesList()) { - entities.add(new Field(entity.getName(), entity.getValueType())); + List entitySpecs = new ArrayList<>(); + for (EntitySpec entitySpec : featureSetSpec.getEntitiesList()) { + entitySpecs.add(new Field(entitySpec)); } return new FeatureSet( @@ -172,8 +187,8 @@ public static FeatureSet fromProto(FeatureSetProto.FeatureSet featureSetProto) { featureSetProto.getSpec().getProject(), featureSetProto.getSpec().getVersion(), featureSetSpec.getMaxAge().getSeconds(), - entities, - features, + entitySpecs, + featureSpecs, source, featureSetProto.getMeta().getStatus()); } @@ -202,24 +217,21 @@ public void addFeature(Field field) { features.add(field); } - public FeatureSetProto.FeatureSet toProto() { + public FeatureSetProto.FeatureSet toProto() throws InvalidProtocolBufferException { List entitySpecs = new ArrayList<>(); - for (Field entity : entities) { - entitySpecs.add( - EntitySpec.newBuilder() - .setName(entity.getName()) - .setValueType(ValueType.Enum.valueOf(entity.getType())) - .build()); + for (Field entityField : entities) { + EntitySpec.Builder entitySpecBuilder = EntitySpec.newBuilder(); + setEntitySpecFields(entitySpecBuilder, entityField); + entitySpecs.add(entitySpecBuilder.build()); } List featureSpecs = new ArrayList<>(); - for (Field feature : features) { - featureSpecs.add( - FeatureSpec.newBuilder() - .setName(feature.getName()) - .setValueType(ValueType.Enum.valueOf(feature.getType())) - .build()); + for (Field featureField : features) { + FeatureSpec.Builder featureSpecBuilder = FeatureSpec.newBuilder(); + setFeatureSpecFields(featureSpecBuilder, featureField); + 
featureSpecs.add(featureSpecBuilder.build()); } + FeatureSetMeta.Builder meta = FeatureSetMeta.newBuilder() .setCreatedTimestamp( @@ -239,6 +251,108 @@ public FeatureSetProto.FeatureSet toProto() { return FeatureSetProto.FeatureSet.newBuilder().setMeta(meta).setSpec(spec).build(); } + // setEntitySpecFields and setFeatureSpecFields methods contain duplicated code because + // Feast internally treat EntitySpec and FeatureSpec as Field class. However, the proto message + // builder for EntitySpec and FeatureSpec are of different class. + @SuppressWarnings("DuplicatedCode") + private void setEntitySpecFields(EntitySpec.Builder entitySpecBuilder, Field entityField) + throws InvalidProtocolBufferException { + entitySpecBuilder + .setName(entityField.getName()) + .setValueType(Enum.valueOf(entityField.getType())); + + if (entityField.getPresence() != null) { + entitySpecBuilder.setPresence(FeaturePresence.parseFrom(entityField.getPresence())); + } else if (entityField.getGroupPresence() != null) { + entitySpecBuilder + .setGroupPresence(FeaturePresenceWithinGroup.parseFrom(entityField.getGroupPresence())); + } + + if (entityField.getShape() != null) { + entitySpecBuilder.setShape(FixedShape.parseFrom(entityField.getShape())); + } else if (entityField.getValueCount() != null) { + entitySpecBuilder.setValueCount(ValueCount.parseFrom(entityField.getValueCount())); + } + + if (entityField.getDomain() != null) { + entitySpecBuilder.setDomain(entityField.getDomain()); + } else if (entityField.getIntDomain() != null) { + entitySpecBuilder.setIntDomain(IntDomain.parseFrom(entityField.getIntDomain())); + } else if (entityField.getFloatDomain() != null) { + entitySpecBuilder.setFloatDomain(FloatDomain.parseFrom(entityField.getFloatDomain())); + } else if (entityField.getStringDomain() != null) { + entitySpecBuilder.setStringDomain(StringDomain.parseFrom(entityField.getStringDomain())); + } else if (entityField.getBoolDomain() != null) { + 
entitySpecBuilder.setBoolDomain(BoolDomain.parseFrom(entityField.getBoolDomain())); + } else if (entityField.getStructDomain() != null) { + entitySpecBuilder.setStructDomain(StructDomain.parseFrom(entityField.getStructDomain())); + } else if (entityField.getNaturalLanguageDomain() != null) { + entitySpecBuilder.setNaturalLanguageDomain( + NaturalLanguageDomain.parseFrom(entityField.getNaturalLanguageDomain())); + } else if (entityField.getImageDomain() != null) { + entitySpecBuilder.setImageDomain(ImageDomain.parseFrom(entityField.getImageDomain())); + } else if (entityField.getMidDomain() != null) { + entitySpecBuilder.setMidDomain(org.tensorflow.metadata.v0.MIDDomain.parseFrom(entityField.getMidDomain())); + } else if (entityField.getUrlDomain() != null) { + entitySpecBuilder.setUrlDomain(URLDomain.parseFrom(entityField.getUrlDomain())); + } else if (entityField.getTimeDomain() != null) { + entitySpecBuilder.setTimeDomain(TimeDomain.parseFrom(entityField.getTimeDomain())); + } else if (entityField.getTimeOfDayDomain() != null) { + entitySpecBuilder + .setTimeOfDayDomain(TimeOfDayDomain.parseFrom(entityField.getTimeOfDayDomain())); + } + } + + // Refer to setEntitySpecFields method for the reason for code duplication. 
+ @SuppressWarnings("DuplicatedCode") + private void setFeatureSpecFields(FeatureSpec.Builder featureSpecBuilder, Field featureField) + throws InvalidProtocolBufferException { + featureSpecBuilder + .setName(featureField.getName()) + .setValueType(Enum.valueOf(featureField.getType())); + + if (featureField.getPresence() != null) { + featureSpecBuilder.setPresence(FeaturePresence.parseFrom(featureField.getPresence())); + } else if (featureField.getGroupPresence() != null) { + featureSpecBuilder + .setGroupPresence(FeaturePresenceWithinGroup.parseFrom(featureField.getGroupPresence())); + } + + if (featureField.getShape() != null) { + featureSpecBuilder.setShape(FixedShape.parseFrom(featureField.getShape())); + } else if (featureField.getValueCount() != null) { + featureSpecBuilder.setValueCount(ValueCount.parseFrom(featureField.getValueCount())); + } + + if (featureField.getDomain() != null) { + featureSpecBuilder.setDomain(featureField.getDomain()); + } else if (featureField.getIntDomain() != null) { + featureSpecBuilder.setIntDomain(IntDomain.parseFrom(featureField.getIntDomain())); + } else if (featureField.getFloatDomain() != null) { + featureSpecBuilder.setFloatDomain(FloatDomain.parseFrom(featureField.getFloatDomain())); + } else if (featureField.getStringDomain() != null) { + featureSpecBuilder.setStringDomain(StringDomain.parseFrom(featureField.getStringDomain())); + } else if (featureField.getBoolDomain() != null) { + featureSpecBuilder.setBoolDomain(BoolDomain.parseFrom(featureField.getBoolDomain())); + } else if (featureField.getStructDomain() != null) { + featureSpecBuilder.setStructDomain(StructDomain.parseFrom(featureField.getStructDomain())); + } else if (featureField.getNaturalLanguageDomain() != null) { + featureSpecBuilder.setNaturalLanguageDomain( + NaturalLanguageDomain.parseFrom(featureField.getNaturalLanguageDomain())); + } else if (featureField.getImageDomain() != null) { + 
featureSpecBuilder.setImageDomain(ImageDomain.parseFrom(featureField.getImageDomain())); + } else if (featureField.getMidDomain() != null) { + featureSpecBuilder.setMidDomain(org.tensorflow.metadata.v0.MIDDomain.parseFrom(featureField.getMidDomain())); + } else if (featureField.getUrlDomain() != null) { + featureSpecBuilder.setUrlDomain(URLDomain.parseFrom(featureField.getUrlDomain())); + } else if (featureField.getTimeDomain() != null) { + featureSpecBuilder.setTimeDomain(TimeDomain.parseFrom(featureField.getTimeDomain())); + } else if (featureField.getTimeOfDayDomain() != null) { + featureSpecBuilder + .setTimeOfDayDomain(TimeOfDayDomain.parseFrom(featureField.getTimeOfDayDomain())); + } + } + /** * Checks if the given featureSet's schema and source has is different from this one. * diff --git a/core/src/main/java/feast/core/model/Field.java b/core/src/main/java/feast/core/model/Field.java index 7573fcbf5e..edb0a73acb 100644 --- a/core/src/main/java/feast/core/model/Field.java +++ b/core/src/main/java/feast/core/model/Field.java @@ -16,6 +16,8 @@ */ package feast.core.model; +import feast.core.FeatureSetProto.EntitySpec; +import feast.core.FeatureSetProto.FeatureSpec; import feast.types.ValueProto.ValueType; import java.util.Objects; import javax.persistence.Column; @@ -44,13 +46,175 @@ public class Field { @Column(name = "project") private String project; - public Field() {} + // Presence constraints (refer to proto feast.core.FeatureSet.FeatureSpec) + // Only one of them can be set. + private byte[] presence; + private byte[] groupPresence; + + // Shape type (refer to proto feast.core.FeatureSet.FeatureSpec) + // Only one of them can be set. + private byte[] shape; + private byte[] valueCount; + + // Domain info for the values (refer to proto feast.core.FeatureSet.FeatureSpec) + // Only one of them can be set. 
+ private String domain; + private byte[] intDomain; + private byte[] floatDomain; + private byte[] stringDomain; + private byte[] boolDomain; + private byte[] structDomain; + private byte[] naturalLanguageDomain; + private byte[] imageDomain; + private byte[] midDomain; + private byte[] urlDomain; + private byte[] timeDomain; + private byte[] timeOfDayDomain; + + public Field() { + } public Field(String name, ValueType.Enum type) { this.name = name; this.type = type.toString(); } + public Field(FeatureSpec featureSpec) { + this.name = featureSpec.getName(); + this.type = featureSpec.getValueType().toString(); + + switch (featureSpec.getPresenceConstraintsCase()) { + case PRESENCE: + this.presence = featureSpec.getPresence().toByteArray(); + break; + case GROUP_PRESENCE: + this.groupPresence = featureSpec.getGroupPresence().toByteArray(); + break; + case PRESENCECONSTRAINTS_NOT_SET: + break; + } + + switch (featureSpec.getShapeTypeCase()) { + case SHAPE: + this.shape = featureSpec.getShape().toByteArray(); + break; + case VALUE_COUNT: + this.valueCount = featureSpec.getValueCount().toByteArray(); + break; + case SHAPETYPE_NOT_SET: + break; + } + + switch (featureSpec.getDomainInfoCase()) { + case DOMAIN: + this.domain = featureSpec.getDomain(); + break; + case INT_DOMAIN: + this.intDomain = featureSpec.getIntDomain().toByteArray(); + break; + case FLOAT_DOMAIN: + this.floatDomain = featureSpec.getFloatDomain().toByteArray(); + break; + case STRING_DOMAIN: + this.stringDomain = featureSpec.getStringDomain().toByteArray(); + break; + case BOOL_DOMAIN: + this.boolDomain = featureSpec.getBoolDomain().toByteArray(); + break; + case STRUCT_DOMAIN: + this.structDomain = featureSpec.getStructDomain().toByteArray(); + break; + case NATURAL_LANGUAGE_DOMAIN: + this.naturalLanguageDomain = featureSpec.getNaturalLanguageDomain().toByteArray(); + break; + case IMAGE_DOMAIN: + this.imageDomain = featureSpec.getImageDomain().toByteArray(); + break; + case MID_DOMAIN: + 
this.midDomain = featureSpec.getMidDomain().toByteArray(); + break; + case URL_DOMAIN: + this.urlDomain = featureSpec.getUrlDomain().toByteArray(); + break; + case TIME_DOMAIN: + this.timeDomain = featureSpec.getTimeDomain().toByteArray(); + break; + case TIME_OF_DAY_DOMAIN: + this.timeOfDayDomain = featureSpec.getTimeOfDayDomain().toByteArray(); + break; + case DOMAININFO_NOT_SET: + break; + } + } + + public Field(EntitySpec entitySpec) { + this.name = entitySpec.getName(); + this.type = entitySpec.getValueType().toString(); + + switch (entitySpec.getPresenceConstraintsCase()) { + case PRESENCE: + this.presence = entitySpec.getPresence().toByteArray(); + break; + case GROUP_PRESENCE: + this.groupPresence = entitySpec.getGroupPresence().toByteArray(); + break; + case PRESENCECONSTRAINTS_NOT_SET: + break; + } + + switch (entitySpec.getShapeTypeCase()) { + case SHAPE: + this.shape = entitySpec.getShape().toByteArray(); + break; + case VALUE_COUNT: + this.valueCount = entitySpec.getValueCount().toByteArray(); + break; + case SHAPETYPE_NOT_SET: + break; + } + + switch (entitySpec.getDomainInfoCase()) { + case DOMAIN: + this.domain = entitySpec.getDomain(); + break; + case INT_DOMAIN: + this.intDomain = entitySpec.getIntDomain().toByteArray(); + break; + case FLOAT_DOMAIN: + this.floatDomain = entitySpec.getFloatDomain().toByteArray(); + break; + case STRING_DOMAIN: + this.stringDomain = entitySpec.getStringDomain().toByteArray(); + break; + case BOOL_DOMAIN: + this.boolDomain = entitySpec.getBoolDomain().toByteArray(); + break; + case STRUCT_DOMAIN: + this.structDomain = entitySpec.getStructDomain().toByteArray(); + break; + case NATURAL_LANGUAGE_DOMAIN: + this.naturalLanguageDomain = entitySpec.getNaturalLanguageDomain().toByteArray(); + break; + case IMAGE_DOMAIN: + this.imageDomain = entitySpec.getImageDomain().toByteArray(); + break; + case MID_DOMAIN: + this.midDomain = entitySpec.getMidDomain().toByteArray(); + break; + case URL_DOMAIN: + this.urlDomain = 
entitySpec.getUrlDomain().toByteArray(); + break; + case TIME_DOMAIN: + this.timeDomain = entitySpec.getTimeDomain().toByteArray(); + break; + case TIME_OF_DAY_DOMAIN: + this.timeOfDayDomain = entitySpec.getTimeOfDayDomain().toByteArray(); + break; + case DOMAININFO_NOT_SET: + break; + } + } + @Override public boolean equals(Object o) { if (this == o) { diff --git a/core/src/main/java/feast/core/service/JobCoordinatorService.java b/core/src/main/java/feast/core/service/JobCoordinatorService.java index 23ad041b81..3678135a52 100644 --- a/core/src/main/java/feast/core/service/JobCoordinatorService.java +++ b/core/src/main/java/feast/core/service/JobCoordinatorService.java @@ -16,6 +16,7 @@ */ package feast.core.service; +import com.google.protobuf.InvalidProtocolBufferException; import feast.core.CoreServiceProto.ListFeatureSetsRequest; import feast.core.CoreServiceProto.ListStoresRequest.Filter; import feast.core.CoreServiceProto.ListStoresResponse; @@ -87,7 +88,7 @@ public JobCoordinatorService( */ @Transactional @Scheduled(fixedDelay = POLLING_INTERVAL_MILLISECONDS) - public void Poll() { + public void Poll() throws InvalidProtocolBufferException { log.info("Polling for new jobs..."); List jobUpdateTasks = new ArrayList<>(); ListStoresResponse listStoresResponse = specService.listStores(Filter.newBuilder().build()); diff --git a/core/src/main/java/feast/core/service/SpecService.java b/core/src/main/java/feast/core/service/SpecService.java index 129fa68a82..9016b692d1 100644 --- a/core/src/main/java/feast/core/service/SpecService.java +++ b/core/src/main/java/feast/core/service/SpecService.java @@ -86,7 +86,8 @@ public SpecService( * @param request: GetFeatureSetRequest Request containing filter parameters. * @return Returns a GetFeatureSetResponse containing a feature set.. 
*/ - public GetFeatureSetResponse getFeatureSet(GetFeatureSetRequest request) { + public GetFeatureSetResponse getFeatureSet(GetFeatureSetRequest request) + throws InvalidProtocolBufferException { // Validate input arguments checkValidCharacters(request.getName(), "featureSetName"); @@ -150,7 +151,8 @@ public GetFeatureSetResponse getFeatureSet(GetFeatureSetRequest request) { * @param filter filter containing the desired featureSet name and version filter * @return ListFeatureSetsResponse with list of featureSets found matching the filter */ - public ListFeatureSetsResponse listFeatureSets(ListFeatureSetsRequest.Filter filter) { + public ListFeatureSetsResponse listFeatureSets(ListFeatureSetsRequest.Filter filter) + throws InvalidProtocolBufferException { String name = filter.getFeatureSetName(); String project = filter.getProject(); String version = filter.getFeatureSetVersion(); @@ -165,7 +167,8 @@ public ListFeatureSetsResponse listFeatureSets(ListFeatureSetsRequest.Filter fil checkValidCharactersAllowAsterisk(name, "featureSetName"); checkValidCharactersAllowAsterisk(project, "projectName"); - List featureSets = new ArrayList() {}; + List featureSets = new ArrayList() { + }; if (project.equals("*")) { // Matching all projects @@ -274,13 +277,14 @@ public ListStoresResponse listStores(ListStoresRequest.Filter filter) { * Creates or updates a feature set in the repository. If there is a change in the feature set * schema, then the feature set version will be incremented. * - *

This function is idempotent. If no changes are detected in the incoming featureSet's schema, - * this method will update the incoming featureSet spec with the latest version stored in the - * repository, and return that. + *

This function is idempotent. If no changes are detected in the incoming featureSet's + * schema, this method will update the incoming featureSet spec with the latest version stored in + * the repository, and return that. * * @param newFeatureSet Feature set that will be created or updated. */ - public ApplyFeatureSetResponse applyFeatureSet(FeatureSetProto.FeatureSet newFeatureSet) { + public ApplyFeatureSetResponse applyFeatureSet(FeatureSetProto.FeatureSet newFeatureSet) + throws InvalidProtocolBufferException { // Validate incoming feature set FeatureSetValidator.validateSpec(newFeatureSet); diff --git a/core/src/test/java/feast/core/service/JobCoordinatorServiceTest.java b/core/src/test/java/feast/core/service/JobCoordinatorServiceTest.java index 775cb028b0..67a87e9316 100644 --- a/core/src/test/java/feast/core/service/JobCoordinatorServiceTest.java +++ b/core/src/test/java/feast/core/service/JobCoordinatorServiceTest.java @@ -75,7 +75,7 @@ public void setUp() { } @Test - public void shouldDoNothingIfNoStoresFound() { + public void shouldDoNothingIfNoStoresFound() throws InvalidProtocolBufferException { when(specService.listStores(any())).thenReturn(ListStoresResponse.newBuilder().build()); JobCoordinatorService jcs = new JobCoordinatorService( diff --git a/core/src/test/java/feast/core/service/SpecServiceTest.java b/core/src/test/java/feast/core/service/SpecServiceTest.java index edd99aa494..c533f593e3 100644 --- a/core/src/test/java/feast/core/service/SpecServiceTest.java +++ b/core/src/test/java/feast/core/service/SpecServiceTest.java @@ -37,6 +37,7 @@ import feast.core.CoreServiceProto.UpdateStoreRequest; import feast.core.CoreServiceProto.UpdateStoreResponse; import feast.core.FeatureSetProto; +import feast.core.FeatureSetProto.EntitySpec; import feast.core.FeatureSetProto.FeatureSetSpec; import feast.core.FeatureSetProto.FeatureSetStatus; import feast.core.FeatureSetProto.FeatureSpec; @@ -61,6 +62,7 @@ import java.util.ArrayList; import java.util.Arrays; 
import java.util.Collections; +import java.util.Comparator; import java.util.List; import java.util.Optional; import java.util.stream.Collectors; @@ -71,16 +73,28 @@ import org.mockito.ArgumentCaptor; import org.mockito.ArgumentMatchers; import org.mockito.Mock; +import org.tensorflow.metadata.v0.BoolDomain; +import org.tensorflow.metadata.v0.FeaturePresence; +import org.tensorflow.metadata.v0.FeaturePresenceWithinGroup; +import org.tensorflow.metadata.v0.FixedShape; +import org.tensorflow.metadata.v0.FloatDomain; +import org.tensorflow.metadata.v0.IntDomain; +import org.tensorflow.metadata.v0.StringDomain; +import org.tensorflow.metadata.v0.ValueCount; public class SpecServiceTest { - @Mock private FeatureSetRepository featureSetRepository; + @Mock + private FeatureSetRepository featureSetRepository; - @Mock private StoreRepository storeRepository; + @Mock + private StoreRepository storeRepository; - @Mock private ProjectRepository projectRepository; + @Mock + private ProjectRepository projectRepository; - @Rule public final ExpectedException expectedException = ExpectedException.none(); + @Rule + public final ExpectedException expectedException = ExpectedException.none(); private SpecService specService; private List featureSets; @@ -126,25 +140,25 @@ public void setUp() { when(featureSetRepository.findFeatureSetByNameAndProject_NameAndVersion("f1", "project1", 1)) .thenReturn(featureSets.get(0)); when(featureSetRepository.findAllByNameLikeAndProject_NameOrderByNameAscVersionAsc( - "f1", "project1")) + "f1", "project1")) .thenReturn(featureSets.subList(0, 3)); when(featureSetRepository.findAllByNameLikeAndProject_NameOrderByNameAscVersionAsc( - "f3", "project1")) + "f3", "project1")) .thenReturn(featureSets.subList(4, 5)); when(featureSetRepository.findFirstFeatureSetByNameLikeAndProject_NameOrderByVersionDesc( - "f1", "project1")) + "f1", "project1")) .thenReturn(featureSet1v3); when(featureSetRepository.findAllByNameLikeAndProject_NameOrderByNameAscVersionAsc( 
- "f1", "project1")) + "f1", "project1")) .thenReturn(featureSets.subList(0, 3)); when(featureSetRepository.findAllByNameLikeAndProject_NameOrderByNameAscVersionAsc( - "asd", "project1")) + "asd", "project1")) .thenReturn(Lists.newArrayList()); when(featureSetRepository.findAllByNameLikeAndProject_NameOrderByNameAscVersionAsc( - "f%", "project1")) + "f%", "project1")) .thenReturn(featureSets); when(featureSetRepository.findAllByNameLikeAndProject_NameLikeOrderByNameAscVersionAsc( - "%", "%")) + "%", "%")) .thenReturn(featureSets); when(projectRepository.findAllByArchivedIsFalse()) @@ -163,7 +177,8 @@ public void setUp() { } @Test - public void shouldGetAllFeatureSetsIfOnlyWildcardsProvided() { + public void shouldGetAllFeatureSetsIfOnlyWildcardsProvided() + throws InvalidProtocolBufferException { ListFeatureSetsResponse actual = specService.listFeatureSets( Filter.newBuilder() @@ -182,7 +197,8 @@ public void shouldGetAllFeatureSetsIfOnlyWildcardsProvided() { } @Test - public void listFeatureSetShouldFailIfFeatureSetProvidedWithoutProject() { + public void listFeatureSetShouldFailIfFeatureSetProvidedWithoutProject() + throws InvalidProtocolBufferException { expectedException.expect(IllegalArgumentException.class); expectedException.expectMessage( "Invalid listFeatureSetRequest, missing arguments. 
Must provide project, feature set name, and version."); @@ -191,7 +207,8 @@ public void listFeatureSetShouldFailIfFeatureSetProvidedWithoutProject() { } @Test - public void shouldGetAllFeatureSetsMatchingNameIfWildcardVersionProvided() { + public void shouldGetAllFeatureSetsMatchingNameIfWildcardVersionProvided() + throws InvalidProtocolBufferException { ListFeatureSetsResponse actual = specService.listFeatureSets( Filter.newBuilder() @@ -212,7 +229,8 @@ public void shouldGetAllFeatureSetsMatchingNameIfWildcardVersionProvided() { } @Test - public void shouldGetAllFeatureSetsMatchingNameWithWildcardSearch() { + public void shouldGetAllFeatureSetsMatchingNameWithWildcardSearch() + throws InvalidProtocolBufferException { ListFeatureSetsResponse actual = specService.listFeatureSets( Filter.newBuilder() @@ -235,7 +253,8 @@ public void shouldGetAllFeatureSetsMatchingNameWithWildcardSearch() { } @Test - public void shouldGetAllFeatureSetsMatchingVersionIfNoComparator() { + public void shouldGetAllFeatureSetsMatchingVersionIfNoComparator() + throws InvalidProtocolBufferException { ListFeatureSetsResponse actual = specService.listFeatureSets( Filter.newBuilder() @@ -259,7 +278,8 @@ public void shouldGetAllFeatureSetsMatchingVersionIfNoComparator() { } @Test - public void shouldThrowExceptionIfGetAllFeatureSetsGivenVersionWithComparator() { + public void shouldThrowExceptionIfGetAllFeatureSetsGivenVersionWithComparator() + throws InvalidProtocolBufferException { expectedException.expect(IllegalArgumentException.class); specService.listFeatureSets( Filter.newBuilder() @@ -270,7 +290,8 @@ public void shouldThrowExceptionIfGetAllFeatureSetsGivenVersionWithComparator() } @Test - public void shouldGetLatestFeatureSetGivenMissingVersionFilter() { + public void shouldGetLatestFeatureSetGivenMissingVersionFilter() + throws InvalidProtocolBufferException { GetFeatureSetResponse actual = specService.getFeatureSet( 
GetFeatureSetRequest.newBuilder().setName("f1").setProject("project1").build()); @@ -279,7 +300,8 @@ public void shouldGetLatestFeatureSetGivenMissingVersionFilter() { } @Test - public void shouldGetSpecificFeatureSetGivenSpecificVersionFilter() { + public void shouldGetSpecificFeatureSetGivenSpecificVersionFilter() + throws InvalidProtocolBufferException { when(featureSetRepository.findFeatureSetByNameAndProject_NameAndVersion("f1", "project1", 2)) .thenReturn(featureSets.get(1)); GetFeatureSetResponse actual = @@ -294,14 +316,15 @@ public void shouldGetSpecificFeatureSetGivenSpecificVersionFilter() { } @Test - public void shouldThrowExceptionGivenMissingFeatureSetName() { + public void shouldThrowExceptionGivenMissingFeatureSetName() + throws InvalidProtocolBufferException { expectedException.expect(IllegalArgumentException.class); expectedException.expectMessage("No feature set name provided"); specService.getFeatureSet(GetFeatureSetRequest.newBuilder().setVersion(2).build()); } @Test - public void shouldThrowExceptionGivenMissingFeatureSet() { + public void shouldThrowExceptionGivenMissingFeatureSet() throws InvalidProtocolBufferException { expectedException.expect(RetrievalException.class); expectedException.expectMessage( "Feature set with name \"f1000\" and version \"2\" could not be found."); @@ -314,7 +337,8 @@ public void shouldThrowExceptionGivenMissingFeatureSet() { } @Test - public void shouldThrowRetrievalExceptionGivenInvalidFeatureSetVersionComparator() { + public void shouldThrowRetrievalExceptionGivenInvalidFeatureSetVersionComparator() + throws InvalidProtocolBufferException { expectedException.expect(IllegalArgumentException.class); expectedException.expectMessage( "Invalid listFeatureSetRequest. 
Version must be set to \"*\" if the project name and feature set name aren't set explicitly: \n" @@ -361,7 +385,8 @@ public void shouldThrowRetrievalExceptionIfNoStoresFoundWithName() { } @Test - public void applyFeatureSetShouldReturnFeatureSetWithLatestVersionIfFeatureSetHasNotChanged() { + public void applyFeatureSetShouldReturnFeatureSetWithLatestVersionIfFeatureSetHasNotChanged() + throws InvalidProtocolBufferException { FeatureSetSpec incomingFeatureSetSpec = featureSets.get(2).toProto().getSpec().toBuilder().clearVersion().build(); @@ -375,9 +400,10 @@ public void applyFeatureSetShouldReturnFeatureSetWithLatestVersionIfFeatureSetHa } @Test - public void applyFeatureSetShouldApplyFeatureSetWithInitVersionIfNotExists() { + public void applyFeatureSetShouldApplyFeatureSetWithInitVersionIfNotExists() + throws InvalidProtocolBufferException { when(featureSetRepository.findAllByNameLikeAndProject_NameOrderByNameAscVersionAsc( - "f2", "project1")) + "f2", "project1")) .thenReturn(Lists.newArrayList()); FeatureSetProto.FeatureSet incomingFeatureSet = @@ -408,7 +434,8 @@ public void applyFeatureSetShouldApplyFeatureSetWithInitVersionIfNotExists() { } @Test - public void applyFeatureSetShouldIncrementFeatureSetVersionIfAlreadyExists() { + public void applyFeatureSetShouldIncrementFeatureSetVersionIfAlreadyExists() + throws InvalidProtocolBufferException { FeatureSetProto.FeatureSet incomingFeatureSet = featureSets.get(2).toProto(); incomingFeatureSet = incomingFeatureSet @@ -450,21 +477,22 @@ public void applyFeatureSetShouldIncrementFeatureSetVersionIfAlreadyExists() { } @Test - public void applyFeatureSetShouldNotCreateFeatureSetIfFieldsUnordered() { + public void applyFeatureSetShouldNotCreateFeatureSetIfFieldsUnordered() + throws InvalidProtocolBufferException { Field f3f1 = new Field("f3f1", Enum.INT64); Field f3f2 = new Field("f3f2", Enum.INT64); Field f3e1 = new Field("f3e1", Enum.STRING); FeatureSetProto.FeatureSet incomingFeatureSet = (new FeatureSet( - "f3", 
- "project1", - 5, - 100L, - Arrays.asList(f3e1), - Arrays.asList(f3f2, f3f1), - defaultSource, - FeatureSetStatus.STATUS_READY)) + "f3", + "project1", + 5, + 100L, + Arrays.asList(f3e1), + Arrays.asList(f3f2, f3f1), + defaultSource, + FeatureSetStatus.STATUS_READY)) .toProto(); ApplyFeatureSetResponse applyFeatureSetResponse = @@ -481,6 +509,108 @@ public void applyFeatureSetShouldNotCreateFeatureSetIfFieldsUnordered() { equalTo(incomingFeatureSet.getSpec().getName())); } + @Test + public void applyFeatureSetShouldAcceptPresenceShapeAndDomainConstraints() + throws InvalidProtocolBufferException { + List entitySpecs = new ArrayList<>(); + entitySpecs.add(EntitySpec.newBuilder().setName("entity1") + .setValueType(Enum.INT64) + .setPresence(FeaturePresence.getDefaultInstance()) + .setShape(FixedShape.getDefaultInstance()) + .setDomain("mydomain") + .build()); + entitySpecs.add(EntitySpec.newBuilder().setName("entity2") + .setValueType(Enum.INT64) + .setGroupPresence(FeaturePresenceWithinGroup.getDefaultInstance()) + .setValueCount(ValueCount.getDefaultInstance()) + .setIntDomain(IntDomain.getDefaultInstance()) + .build()); + entitySpecs.add(EntitySpec.newBuilder().setName("entity3") + .setValueType(Enum.FLOAT) + .setPresence(FeaturePresence.getDefaultInstance()) + .setValueCount(ValueCount.getDefaultInstance()) + .setFloatDomain(FloatDomain.getDefaultInstance()) + .build()); + entitySpecs.add(EntitySpec.newBuilder().setName("entity4") + .setValueType(Enum.STRING) + .setPresence(FeaturePresence.getDefaultInstance()) + .setValueCount(ValueCount.getDefaultInstance()) + .setStringDomain(StringDomain.getDefaultInstance()) + .build()); + entitySpecs.add(EntitySpec.newBuilder().setName("entity5") + .setValueType(Enum.BOOL) + .setPresence(FeaturePresence.getDefaultInstance()) + .setValueCount(ValueCount.getDefaultInstance()) + .setBoolDomain(BoolDomain.getDefaultInstance()) + .build()); + + List featureSpecs = new ArrayList<>(); + 
featureSpecs.add(FeatureSpec.newBuilder().setName("feature1") + .setValueType(Enum.INT64) + .setPresence(FeaturePresence.getDefaultInstance()) + .setShape(FixedShape.getDefaultInstance()) + .setDomain("mydomain") + .build()); + featureSpecs.add(FeatureSpec.newBuilder().setName("feature2") + .setValueType(Enum.INT64) + .setGroupPresence(FeaturePresenceWithinGroup.getDefaultInstance()) + .setValueCount(ValueCount.getDefaultInstance()) + .setIntDomain(IntDomain.getDefaultInstance()) + .build()); + featureSpecs.add(FeatureSpec.newBuilder().setName("feature3") + .setValueType(Enum.FLOAT) + .setPresence(FeaturePresence.getDefaultInstance()) + .setValueCount(ValueCount.getDefaultInstance()) + .setFloatDomain(FloatDomain.getDefaultInstance()) + .build()); + featureSpecs.add(FeatureSpec.newBuilder().setName("feature4") + .setValueType(Enum.STRING) + .setPresence(FeaturePresence.getDefaultInstance()) + .setValueCount(ValueCount.getDefaultInstance()) + .setStringDomain(StringDomain.getDefaultInstance()) + .build()); + featureSpecs.add(FeatureSpec.newBuilder().setName("feature5") + .setValueType(Enum.BOOL) + .setPresence(FeaturePresence.getDefaultInstance()) + .setValueCount(ValueCount.getDefaultInstance()) + .setBoolDomain(BoolDomain.getDefaultInstance()) + .build()); + + FeatureSetSpec featureSetSpec = FeatureSetSpec.newBuilder() + .setProject("project1") + .setName("featureSetWithConstraints") + .addAllEntities(entitySpecs) + .addAllFeatures(featureSpecs) + .build(); + FeatureSetProto.FeatureSet featureSet = FeatureSetProto.FeatureSet.newBuilder() + .setSpec(featureSetSpec) + .build(); + + ApplyFeatureSetResponse applyFeatureSetResponse = specService.applyFeatureSet(featureSet); + FeatureSetSpec appliedFeatureSetSpec = applyFeatureSetResponse.getFeatureSet().getSpec(); + + // appliedEntitySpecs needs to be sorted because the list returned by specService may not + // follow the order in the request + List appliedEntitySpecs = new 
ArrayList<>(appliedFeatureSetSpec.getEntitiesList()); + appliedEntitySpecs.sort(Comparator.comparing(EntitySpec::getName)); + + // appliedFeatureSpecs needs to be sorted because the list returned by specService may not + // follow the order in the request + List appliedFeatureSpecs = new ArrayList<>(appliedFeatureSetSpec.getFeaturesList()); + appliedFeatureSpecs.sort(Comparator.comparing(FeatureSpec::getName)); + + assertEquals(appliedEntitySpecs.size(), entitySpecs.size()); + assertEquals(appliedFeatureSpecs.size(), featureSpecs.size()); + + for (int i = 0; i < appliedEntitySpecs.size(); i++) { + assertEquals(entitySpecs.get(i), appliedEntitySpecs.get(i)); + } + + for (int i = 0; i < appliedFeatureSpecs.size(); i++) { + assertEquals(featureSpecs.get(i), appliedFeatureSpecs.get(i)); + } + } + @Test public void shouldUpdateStoreIfConfigChanges() throws InvalidProtocolBufferException { when(storeRepository.findById("SERVING")).thenReturn(Optional.of(stores.get(0))); @@ -521,7 +651,7 @@ public void shouldDoNothingIfNoChange() throws InvalidProtocolBufferException { } @Test - public void shouldFailIfGetFeatureSetWithoutProject() { + public void shouldFailIfGetFeatureSetWithoutProject() throws InvalidProtocolBufferException { expectedException.expect(IllegalArgumentException.class); expectedException.expectMessage("No project provided"); specService.getFeatureSet(GetFeatureSetRequest.newBuilder().setName("f1").build()); @@ -530,6 +660,7 @@ public void shouldFailIfGetFeatureSetWithoutProject() { private FeatureSet newDummyFeatureSet(String name, int version, String project) { Field feature = new Field("feature", Enum.INT64); Field entity = new Field("entity", Enum.STRING); + FeatureSet fs = new FeatureSet( name, @@ -553,4 +684,6 @@ private Store newDummyStore(String name) { store.setConfig(RedisConfig.newBuilder().setPort(6379).build().toByteArray()); return store; } + + } diff --git a/datatypes/java/src/main/proto/tensorflow_metadata 
b/datatypes/java/src/main/proto/tensorflow_metadata new file mode 120000 index 0000000000..a633bb850f --- /dev/null +++ b/datatypes/java/src/main/proto/tensorflow_metadata @@ -0,0 +1 @@ +../../../../../protos/tensorflow_metadata \ No newline at end of file diff --git a/protos/feast/core/FeatureSet.proto b/protos/feast/core/FeatureSet.proto index 910cc375f7..429d99c854 100644 --- a/protos/feast/core/FeatureSet.proto +++ b/protos/feast/core/FeatureSet.proto @@ -24,6 +24,7 @@ import "feast/types/Value.proto"; import "feast/core/Source.proto"; import "google/protobuf/duration.proto"; import "google/protobuf/timestamp.proto"; +import "tensorflow_metadata/proto/v0/schema.proto"; message FeatureSet { // User-specified specifications of this feature set. @@ -67,6 +68,46 @@ message EntitySpec { // Value type of the feature. feast.types.ValueType.Enum value_type = 2; + + // presence_constraints, shape_type and domain_info are referenced from: + // https://github.com/tensorflow/metadata/blob/36f65d1268cbc92cdbcf812ee03dcf47fb53b91e/tensorflow_metadata/proto/v0/schema.proto#L107 + + oneof presence_constraints { + // Constraints on the presence of this feature in the examples. + tensorflow.metadata.v0.FeaturePresence presence = 3; + // Only used in the context of a "group" context, e.g., inside a sequence. + tensorflow.metadata.v0.FeaturePresenceWithinGroup group_presence = 4; + } + + // The shape of the feature which governs the number of values that appear in + // each example. + oneof shape_type { + // The feature has a fixed shape corresponding to a multi-dimensional + // tensor. + tensorflow.metadata.v0.FixedShape shape = 5; + // The feature doesn't have a well defined shape. All we know are limits on + // the minimum and maximum number of values. + tensorflow.metadata.v0.ValueCount value_count = 6; + } + + // Domain for the values of the feature. + oneof domain_info { + // Reference to a domain defined at the schema level. 
+ string domain = 7; + // Inline definitions of domains. + tensorflow.metadata.v0.IntDomain int_domain = 8; + tensorflow.metadata.v0.FloatDomain float_domain = 9; + tensorflow.metadata.v0.StringDomain string_domain = 10; + tensorflow.metadata.v0.BoolDomain bool_domain = 11; + tensorflow.metadata.v0.StructDomain struct_domain = 12; + // Supported semantic domains. + tensorflow.metadata.v0.NaturalLanguageDomain natural_language_domain = 13; + tensorflow.metadata.v0.ImageDomain image_domain = 14; + tensorflow.metadata.v0.MIDDomain mid_domain = 15; + tensorflow.metadata.v0.URLDomain url_domain = 16; + tensorflow.metadata.v0.TimeDomain time_domain = 17; + tensorflow.metadata.v0.TimeOfDayDomain time_of_day_domain = 18; + } } message FeatureSpec { @@ -75,6 +116,46 @@ message FeatureSpec { // Value type of the feature. feast.types.ValueType.Enum value_type = 2; + + // presence_constraints, shape_type and domain_info are referenced from: + // https://github.com/tensorflow/metadata/blob/36f65d1268cbc92cdbcf812ee03dcf47fb53b91e/tensorflow_metadata/proto/v0/schema.proto#L107 + + oneof presence_constraints { + // Constraints on the presence of this feature in the examples. + tensorflow.metadata.v0.FeaturePresence presence = 3; + // Only used in the context of a "group" context, e.g., inside a sequence. + tensorflow.metadata.v0.FeaturePresenceWithinGroup group_presence = 4; + } + + // The shape of the feature which governs the number of values that appear in + // each example. + oneof shape_type { + // The feature has a fixed shape corresponding to a multi-dimensional + // tensor. + tensorflow.metadata.v0.FixedShape shape = 5; + // The feature doesn't have a well defined shape. All we know are limits on + // the minimum and maximum number of values. + tensorflow.metadata.v0.ValueCount value_count = 6; + } + + // Domain for the values of the feature. + oneof domain_info { + // Reference to a domain defined at the schema level. 
+ string domain = 7; + // Inline definitions of domains. + tensorflow.metadata.v0.IntDomain int_domain = 8; + tensorflow.metadata.v0.FloatDomain float_domain = 9; + tensorflow.metadata.v0.StringDomain string_domain = 10; + tensorflow.metadata.v0.BoolDomain bool_domain = 11; + tensorflow.metadata.v0.StructDomain struct_domain = 12; + // Supported semantic domains. + tensorflow.metadata.v0.NaturalLanguageDomain natural_language_domain = 13; + tensorflow.metadata.v0.ImageDomain image_domain = 14; + tensorflow.metadata.v0.MIDDomain mid_domain = 15; + tensorflow.metadata.v0.URLDomain url_domain = 16; + tensorflow.metadata.v0.TimeDomain time_domain = 17; + tensorflow.metadata.v0.TimeOfDayDomain time_of_day_domain = 18; + } } message FeatureSetMeta { diff --git a/protos/tensorflow_metadata/proto/v0/path.proto b/protos/tensorflow_metadata/proto/v0/path.proto new file mode 100644 index 0000000000..cac09b7a08 --- /dev/null +++ b/protos/tensorflow_metadata/proto/v0/path.proto @@ -0,0 +1,43 @@ +// Copyright 2018 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= + +syntax = "proto2"; +option cc_enable_arenas = true; + +package tensorflow.metadata.v0; + +option java_package = "org.tensorflow.metadata.v0"; +option java_multiple_files = true; + +// A path is a more general substitute for the name of a field or feature that +// can be used for flat examples as well as structured data. For example, if +// we had data in a protocol buffer: +// message Person { +// int age = 1; +// optional string gender = 2; +// repeated Person parent = 3; +// } +// Thus, here the path {step:["parent", "age"]} in statistics would refer to the +// age of a parent, and {step:["parent", "parent", "age"]} would refer to the +// age of a grandparent. This allows us to distinguish between the statistics +// of parents' ages and grandparents' ages. In general, repeated messages are +// to be preferred to linked lists of arbitrary length. +// For SequenceExample, if we have a feature list "foo", this is represented +// by {step:["##SEQUENCE##", "foo"]}. +message Path { + // Any string is a valid step. + // However, whenever possible have a step be [A-Za-z0-9_]+. + repeated string step = 1; +} diff --git a/protos/tensorflow_metadata/proto/v0/schema.proto b/protos/tensorflow_metadata/proto/v0/schema.proto new file mode 100644 index 0000000000..ce30515c69 --- /dev/null +++ b/protos/tensorflow_metadata/proto/v0/schema.proto @@ -0,0 +1,672 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +syntax = "proto2"; + +package tensorflow.metadata.v0; + +import "google/protobuf/any.proto"; +import "tensorflow_metadata/proto/v0/path.proto"; + +option cc_enable_arenas = true; +option java_package = "org.tensorflow.metadata.v0"; +option java_multiple_files = true; + +// LifecycleStage. Only UNKNOWN_STAGE, BETA, and PRODUCTION features are +// actually validated. +// PLANNED, ALPHA, and DEBUG are treated as DEPRECATED. +enum LifecycleStage { + UNKNOWN_STAGE = 0; // Unknown stage. + PLANNED = 1; // Planned feature, may not be created yet. + ALPHA = 2; // Prototype feature, not used in experiments yet. + BETA = 3; // Used in user-facing experiments. + PRODUCTION = 4; // Used in a significant fraction of user traffic. + DEPRECATED = 5; // No longer supported: do not use in new models. + DEBUG_ONLY = 6; // Only exists for debugging purposes. +} + +// +// Message to represent schema information. +// NextID: 14 +message Schema { + // Features described in this schema. + repeated Feature feature = 1; + + // Sparse features described in this schema. + repeated SparseFeature sparse_feature = 6; + + // Weighted features described in this schema. + repeated WeightedFeature weighted_feature = 12; + + // Use StructDomain instead. + // Sequences described in this schema. A sequence may be described in terms of + // several features. Any features appearing within a sequence must *not* be + // declared as top-level features in . +// GOOGLE-LEGACY repeated Sequence sequence = 2; + + // declared as top-level features in . + // String domains referenced in the features. 
+ repeated StringDomain string_domain = 4; + + // top level float domains that can be reused by features + repeated FloatDomain float_domain = 9; + + // top level int domains that can be reused by features + repeated IntDomain int_domain = 10; + + // Default environments for each feature. + // An environment represents both a type of location (e.g. a server or phone) + // and a time (e.g. right before model X is run). In the standard scenario, + // 99% of the features should be in the default environments TRAINING, + // SERVING, and the LABEL (or labels) AND WEIGHT is only available at TRAINING + // (not at serving). + // Other possible variations: + // 1. There may be TRAINING_MOBILE, SERVING_MOBILE, TRAINING_SERVICE, + // and SERVING_SERVICE. + // 2. If one is ensembling three models, where the predictions of the first + // three models are available for the ensemble model, there may be + // TRAINING, SERVING_INITIAL, SERVING_ENSEMBLE. + // See FeatureProto::not_in_environment and FeatureProto::in_environment. + repeated string default_environment = 5; + + /* BEGIN GOOGLE-LEGACY + // TODO(b/73109633): Change default to false, before removing this field. + optional bool generate_legacy_feature_spec = 7 [default = true]; + END GOOGLE-LEGACY */ + + // Additional information about the schema as a whole. Features may also + // be annotated individually. + optional Annotation annotation = 8; + + // Dataset-level constraints. This is currently used for specifying + // information about changes in num_examples. + optional DatasetConstraints dataset_constraints = 11; + + // TensorRepresentation groups. The keys are the names of the groups. + // Key "" (empty string) denotes the "default" group, which is what should + // be used when a group name is not provided. + // See the documentation at TensorRepresentationGroup for more info. + // Under development. DO NOT USE. 
+ map tensor_representation_group = 13; +} + +// Describes schema-level information about a specific feature. +// NextID: 31 +message Feature { + // The name of the feature. + optional string name = 1; // required + + // This field is no longer supported. Instead, use: + // lifecycle_stage: DEPRECATED + // TODO(b/111450258): remove this. + optional bool deprecated = 2 [deprecated = true]; + + // Comment field for a human readable description of the field. + // TODO(b/123518108): remove this. +// GOOGLE-LEGACY optional string comment = 3 [deprecated = true]; + + oneof presence_constraints { + // Constraints on the presence of this feature in the examples. + FeaturePresence presence = 14; + // Only used in the context of a "group" context, e.g., inside a sequence. + FeaturePresenceWithinGroup group_presence = 17; + } + + // The shape of the feature which governs the number of values that appear in + // each example. + oneof shape_type { + // The feature has a fixed shape corresponding to a multi-dimensional + // tensor. + FixedShape shape = 23; + // The feature doesn't have a well defined shape. All we know are limits on + // the minimum and maximum number of values. + ValueCount value_count = 5; + } + + // Physical type of the feature's values. + // Note that you can have: + // type: BYTES + // int_domain: { + // min: 0 + // max: 3 + // } + // This would be a field that is syntactically BYTES (i.e. strings), but + // semantically an int, i.e. it would be "0", "1", "2", or "3". + optional FeatureType type = 6; + + // Domain for the values of the feature. + oneof domain_info { + // Reference to a domain defined at the schema level. + string domain = 7; + // Inline definitions of domains. + IntDomain int_domain = 9; + FloatDomain float_domain = 10; + StringDomain string_domain = 11; + BoolDomain bool_domain = 13; + StructDomain struct_domain = 29; + // Supported semantic domains. 
+ NaturalLanguageDomain natural_language_domain = 24; + ImageDomain image_domain = 25; + MIDDomain mid_domain = 26; + URLDomain url_domain = 27; + TimeDomain time_domain = 28; + TimeOfDayDomain time_of_day_domain = 30; + } + + // Constraints on the distribution of the feature values. + // Currently only supported for StringDomains. + // TODO(b/69473628): Extend functionality to other domain types. + optional DistributionConstraints distribution_constraints = 15; + + // Additional information about the feature for documentation purpose. + optional Annotation annotation = 16; + + // Tests comparing the distribution to the associated serving data. + optional FeatureComparator skew_comparator = 18; + + // Tests comparing the distribution between two consecutive spans (e.g. days). + optional FeatureComparator drift_comparator = 21; + + // List of environments this feature is present in. + // Should be disjoint from not_in_environment. + // This feature is in environment "foo" if: + // ("foo" is in in_environment or default_environments) AND + // "foo" is not in not_in_environment. + // See Schema::default_environments. + repeated string in_environment = 20; + + // List of environments this feature is not present in. + // Should be disjoint from of in_environment. + // See Schema::default_environments and in_environment. + repeated string not_in_environment = 19; + + // The lifecycle stage of a feature. It can also apply to its descendants. + // i.e., if a struct is DEPRECATED, its children are implicitly deprecated. + optional LifecycleStage lifecycle_stage = 22; +} + +// Additional information about the schema or about a feature. +message Annotation { + // Tags can be used to mark features. For example, tag on user_age feature can + // be `user_feature`, tag on user_country feature can be `location_feature`, + // `user_feature`. + repeated string tag = 1; + // Free-text comments. This can be used as a description of the feature, + // developer notes etc. 
+ repeated string comment = 2; + // Application-specific metadata may be attached here. + repeated .google.protobuf.Any extra_metadata = 3; +} + +// Checks that the ratio of the current value to the previous value is not below +// the min_fraction_threshold or above the max_fraction_threshold. That is, +// previous value * min_fraction_threshold <= current value <= +// previous value * max_fraction_threshold. +// To specify that the value cannot change, set both min_fraction_threshold and +// max_fraction_threshold to 1.0. +message NumericValueComparator { + optional double min_fraction_threshold = 1; + optional double max_fraction_threshold = 2; +} + +// Constraints on the entire dataset. +message DatasetConstraints { + // Tests differences in number of examples between the current data and the + // previous span. + optional NumericValueComparator num_examples_drift_comparator = 1; + // Tests comparisions in number of examples between the current data and the + // previous version of that data. + optional NumericValueComparator num_examples_version_comparator = 2; + // Minimum number of examples in the dataset. + optional int64 min_examples_count = 3; +} + +// Specifies a fixed shape for the feature's values. The immediate implication +// is that each feature has a fixed number of values. Moreover, these values +// can be parsed in a multi-dimensional tensor using the specified axis sizes. +// The FixedShape defines a lexicographical ordering of the data. For instance, +// if there is a FixedShape { +// dim {size:3} dim {size:2} +// } +// then tensor[0][0]=field[0] +// then tensor[0][1]=field[1] +// then tensor[1][0]=field[2] +// then tensor[1][1]=field[3] +// then tensor[2][0]=field[4] +// then tensor[2][1]=field[5] +// +// The FixedShape message is identical with the TensorFlow TensorShape proto +// message. +message FixedShape { + // The dimensions that define the shape. The total number of values in each + // example is the product of sizes of each dimension. 
+ repeated Dim dim = 2; + + // An axis in a multi-dimensional feature representation. + message Dim { + optional int64 size = 1; + + // Optional name of the tensor dimension. + optional string name = 2; + } +} + +// Limits on maximum and minimum number of values in a +// single example (when the feature is present). Use this when the minimum +// value count can be different than the maximum value count. Otherwise prefer +// FixedShape. +message ValueCount { + optional int64 min = 1; + optional int64 max = 2; +} + +/* BEGIN GOOGLE-LEGACY +// Constraint on the number of elements in a sequence. +message LengthConstraint { + optional int64 min = 1; + optional int64 max = 2; +} + +// A sequence is a logical feature that comprises several "raw" features that +// encode values at different "steps" within the sequence. +// TODO(b/110490010): Delete this. This is a special case of StructDomain. +message Sequence { + // An optional name for this sequence. Used mostly for debugging and + // presentation. + optional string name = 1; + + // Features that comprise the sequence. These features are "zipped" together + // to form the values for the sequence at different steps. + // - Use group_presence within each feature to encode presence constraints + // within the sequence. + // - If all features have the same value-count constraints then + // declare this once using the shape_constraint below. + repeated Feature feature = 2; + + // Constraints on the presence of the sequence across all examples in the + // dataset. The sequence is assumed to be present if at least one of its + // features is present. + optional FeaturePresence presence = 3; + + // Shape constraints that apply on all the features that comprise the + // sequence. If this is set then the value_count in 'feature' is + // ignored. + // TODO(martinz): delete: there is no reason to believe the shape of the + // fields in a sequence will be the same. Use the fields in Feature instead. 
+ oneof shape_constraint { + ValueCount value_count = 4; + FixedShape fixed_shape = 5; + } + + // Constraint on the number of elements in a sequence. + optional LengthConstraint length_constraint = 6; +} +END GOOGLE-LEGACY */ + +// Represents a weighted feature that is encoded as a combination of raw base +// features. The `weight_feature` should be a float feature with identical +// shape as the `feature`. This is useful for representing weights associated +// with categorical tokens (e.g. a TFIDF weight associated with each token). +// TODO(b/142122960): Handle WeightedCategorical end to end in TFX (validation, +// TFX Unit Testing, etc) +message WeightedFeature { + // Name for the weighted feature. This should not clash with other features in + // the same schema. + optional string name = 1; // required + // Path of a base feature to be weighted. Required. + optional Path feature = 2; + // Path of weight feature to associate with the base feature. Must be same + // shape as feature. Required. + optional Path weight_feature = 3; + // The lifecycle_stage determines where a feature is expected to be used, + // and therefore how important issues with it are. + optional LifecycleStage lifecycle_stage = 4; +} + +// A sparse feature represents a sparse tensor that is encoded with a +// combination of raw features, namely index features and a value feature. Each +// index feature defines a list of indices in a different dimension. +message SparseFeature { + reserved 11; + // Name for the sparse feature. This should not clash with other features in + // the same schema. + optional string name = 1; // required + + // This field is no longer supported. Instead, use: + // lifecycle_stage: DEPRECATED + // TODO(b/111450258): remove this. + optional bool deprecated = 2 [deprecated = true]; + + // The lifecycle_stage determines where a feature is expected to be used, + // and therefore how important issues with it are. 
+ optional LifecycleStage lifecycle_stage = 7; + + // Comment field for a human readable description of the field. + // TODO(martinz): delete, convert to annotation. +// GOOGLE-LEGACY optional string comment = 3 [deprecated = true]; + + // Constraints on the presence of this feature in examples. + // Deprecated, this is inferred by the referred features. + optional FeaturePresence presence = 4 [deprecated = true]; + + // Shape of the sparse tensor that this SparseFeature represents. + // Currently not supported. + // TODO(b/109669962): Consider deriving this from the referred features. + optional FixedShape dense_shape = 5; + + // Features that represent indexes. Should be integers >= 0. + repeated IndexFeature index_feature = 6; // at least one + message IndexFeature { + // Name of the index-feature. This should be a reference to an existing + // feature in the schema. + optional string name = 1; + } + + // If true then the index values are already sorted lexicographically. + optional bool is_sorted = 8; + + optional ValueFeature value_feature = 9; // required + message ValueFeature { + // Name of the value-feature. This should be a reference to an existing + // feature in the schema. + optional string name = 1; + } + + // Type of value feature. + // Deprecated, this is inferred by the referred features. + optional FeatureType type = 10 [deprecated = true]; +} + +// Models constraints on the distribution of a feature's values. +// TODO(martinz): replace min_domain_mass with max_off_domain (but slowly). +message DistributionConstraints { + // The minimum fraction (in [0,1]) of values across all examples that + // should come from the feature's domain, e.g.: + // 1.0 => All values must come from the domain. + // .9 => At least 90% of the values must come from the domain. + optional double min_domain_mass = 1 [default = 1.0]; +} + +// Encodes information for domains of integer values. +// Note that FeatureType could be either INT or BYTES. 
+message IntDomain {
+  // Id of the domain. Required if the domain is defined at the schema level. If
+  // so, then the name must be unique within the schema.
+  optional string name = 1;
+
+  // Min and max values for the domain.
+  optional int64 min = 3;
+  optional int64 max = 4;
+
+  // If true then the domain encodes categorical values (i.e., ids) rather than
+  // ordinal values.
+  optional bool is_categorical = 5;
+}
+
+// Encodes information for domains of float values.
+// Note that FeatureType could be either INT or BYTES.
+// NOTE(review): the line above looks copy-pasted from IntDomain; presumably a
+// FloatDomain may also apply to FLOAT features (or floats encoded in BYTES) —
+// confirm against how FeatureType is consumed.
+message FloatDomain {
+  // Id of the domain. Required if the domain is defined at the schema level. If
+  // so, then the name must be unique within the schema.
+  optional string name = 1;
+
+  // Min and max values of the domain.
+  optional float min = 3;
+  optional float max = 4;
+}
+
+// Domain for a recursive struct.
+// NOTE: If a feature with a StructDomain is deprecated, then all the
+// child features (features and sparse_features of the StructDomain) are also
+// considered to be deprecated. Similarly child features can only be in
+// environments of the parent feature.
+message StructDomain {
+  repeated Feature feature = 1;
+
+  repeated SparseFeature sparse_feature = 2;
+}
+
+// Encodes information for domains of string values.
+message StringDomain {
+  // Id of the domain. Required if the domain is defined at the schema level. If
+  // so, then the name must be unique within the schema.
+  optional string name = 1;
+
+  // The values appearing in the domain.
+  repeated string value = 2;
+}
+
+// Encodes information about the domain of a boolean attribute that encodes its
+// TRUE/FALSE values as strings, or 0=false, 1=true.
+// Note that FeatureType could be either INT or BYTES.
+message BoolDomain {
+  // Id of the domain. Required if the domain is defined at the schema level. If
+  // so, then the name must be unique within the schema.
+  optional string name = 1;
+
+  // String values for TRUE/FALSE.
+  optional string true_value = 2;
+  optional string false_value = 3;
+}
+
+// BEGIN SEMANTIC-TYPES-PROTOS
+// Semantic domains are specialized feature domains. For example a string
+// Feature might represent a Time of a specific format.
+// Semantic domains are defined as protocol buffers to allow further sub-types /
+// specialization, e.g: NaturalLanguageDomain can provide information on the
+// language of the text.
+
+// Natural language text.
+message NaturalLanguageDomain {}
+
+// Image data.
+message ImageDomain {}
+
+// Knowledge graph ID, see: https://www.wikidata.org/wiki/Property:P646
+message MIDDomain {}
+
+// A URL, see: https://en.wikipedia.org/wiki/URL
+message URLDomain {}
+
+// Time or date representation.
+message TimeDomain {
+  enum IntegerTimeFormat {
+    FORMAT_UNKNOWN = 0;
+    UNIX_DAYS = 5;  // Number of days since 1970-01-01.
+    UNIX_SECONDS = 1;
+    UNIX_MILLISECONDS = 2;
+    UNIX_MICROSECONDS = 3;
+    UNIX_NANOSECONDS = 4;
+  }
+
+  oneof format {
+    // Expected format that contains a combination of regular characters and
+    // special format specifiers. Format specifiers are a subset of the
+    // strptime standard.
+    string string_format = 1;
+
+    // Expected format of integer times.
+    IntegerTimeFormat integer_format = 2;
+  }
+}
+
+// Time of day, without a particular date.
+message TimeOfDayDomain {
+  enum IntegerTimeOfDayFormat {
+    FORMAT_UNKNOWN = 0;
+    // Time values, containing hour/minute/second/nanos, encoded into 8-byte
+    // bit fields following the ZetaSQL convention:
+    //        6       5       4       3       2       1
+    // MSB 3210987654321098765432109876543210987654321098765432109876543210 LSB
+    //     | H ||  M  ||  S  ||---------------- nanos ----------------|
+    PACKED_64_NANOS = 1;
+  }
+
+  oneof format {
+    // Expected format that contains a combination of regular characters and
+    // special format specifiers. Format specifiers are a subset of the
+    // strptime standard.
+    string string_format = 1;
+
+    // Expected format of integer times.
+    IntegerTimeOfDayFormat integer_format = 2;
+  }
+}
+// END SEMANTIC-TYPES-PROTOS
+
+// Describes the physical representation of a feature.
+// It may be different than the logical representation, which
+// is represented as a Domain.
+enum FeatureType {
+  TYPE_UNKNOWN = 0;
+  BYTES = 1;
+  INT = 2;
+  FLOAT = 3;
+  STRUCT = 4;
+}
+
+// Describes constraints on the presence of the feature in the data.
+message FeaturePresence {
+  // Minimum fraction of examples that have this feature.
+  optional double min_fraction = 1;
+  // Minimum number of examples that have this feature.
+  optional int64 min_count = 2;
+}
+
+// Records constraints on the presence of a feature inside a "group" context
+// (e.g., .presence inside a group of features that define a sequence).
+message FeaturePresenceWithinGroup {
+  optional bool required = 1;
+}
+
+// Checks that the L-infinity norm is below a certain threshold between the
+// two discrete distributions. Since this is applied to a FeatureNameStatistics,
+// it only considers the top k.
+// L_infty(p,q) = max_i |p_i-q_i|
+message InfinityNorm {
+  // The InfinityNorm is in the interval [0.0, 1.0] so sensible bounds should
+  // be in the interval [0.0, 1.0).
+  optional double threshold = 1;
+}
+
+// Bundles the distance measures used to compare a feature's value distribution
+// against a reference distribution. Currently only the L-infinity norm check
+// (see InfinityNorm above) is available.
+message FeatureComparator {
+  optional InfinityNorm infinity_norm = 1;
+}
+
+// A TensorRepresentation captures the intent for converting columns in a
+// dataset to TensorFlow Tensors (or more generally, tf.CompositeTensors).
+// Note that one tf.CompositeTensor may consist of data from multiple columns,
+// for example, a N-dimensional tf.SparseTensor may need N + 1 columns to
+// provide the sparse indices and values.
+// Note that the "column name" that a TensorRepresentation needs is a
+// string, not a Path -- it means that the column name identifies a top-level
+// Feature in the schema (i.e. you cannot specify a Feature nested in a STRUCT
+// Feature).
+message TensorRepresentation {
+  // Value used to fill rows where the source column has no value.
+  message DefaultValue {
+    oneof kind {
+      double float_value = 1;
+      // Note that the data column might be of a shorter integral type. It's the
+      // user's responsibility to make sure the default value fits that type.
+      int64 int_value = 2;
+      bytes bytes_value = 3;
+      // uint_value should only be used if the default value can't fit in a
+      // int64 (`int_value`).
+      uint64 uint_value = 4;
+    }
+  }
+
+  // A tf.Tensor
+  message DenseTensor {
+    // Identifies the column in the dataset that provides the values of this
+    // Tensor.
+    optional string column_name = 1;
+    // The shape of each row of the data (i.e. does not include the batch
+    // dimension)
+    optional FixedShape shape = 2;
+    // If this column is missing values in a row, the default_value will be
+    // used to fill that row.
+    optional DefaultValue default_value = 3;
+  }
+
+  // A ragged tf.SparseTensor that models nested lists.
+  message VarLenSparseTensor {
+    // Identifies the column in the dataset that should be converted to the
+    // VarLenSparseTensor.
+    optional string column_name = 1;
+  }
+
+  // A tf.SparseTensor whose indices and values come from separate data columns.
+  // This will replace Schema.sparse_feature eventually.
+  // The index columns must be of INT type, and all the columns must co-occur
+  // and have the same valency at the same row.
+  message SparseTensor {
+    // The dense shape of the resulting SparseTensor (does not include the batch
+    // dimension).
+    optional FixedShape dense_shape = 1;
+    // The columns constitute the coordinates of the values.
+    // indices_column[i][j] contains the coordinate of the i-th dimension of the
+    // j-th value.
+    repeated string index_column_names = 2;
+    // The column that contains the values.
+    optional string value_column_name = 3;
+  }
+
+  oneof kind {
+    DenseTensor dense_tensor = 1;
+    VarLenSparseTensor varlen_sparse_tensor = 2;
+    SparseTensor sparse_tensor = 3;
+  }
+}
+
+// A TensorRepresentationGroup is a collection of TensorRepresentations with
+// names. These names may serve as identifiers when converting the dataset
+// to a collection of Tensors or tf.CompositeTensors.
+// For example, given the following group:
+// {
+//   key: "dense_tensor"
+//   tensor_representation {
+//     dense_tensor {
+//       column_name: "univalent_feature"
+//       shape {
+//         dim {
+//           size: 1
+//         }
+//       }
+//       default_value {
+//         float_value: 0
+//       }
+//     }
+//   }
+// }
+// {
+//   key: "varlen_sparse_tensor"
+//   tensor_representation {
+//     varlen_sparse_tensor {
+//       column_name: "multivalent_feature"
+//     }
+//   }
+// }
+//
+// Then the schema is expected to have feature "univalent_feature" and
+// "multivalent_feature", and when a batch of data is converted to Tensors using
+// this TensorRepresentationGroup, the result may be the following dict:
+// {
+//   "dense_tensor": tf.Tensor(...),
+//   "varlen_sparse_tensor": tf.SparseTensor(...),
+// }
+message TensorRepresentationGroup {
+  // Maps a representation name to its TensorRepresentation. The original line
+  // had lost the map's type parameters ("map tensor_representation = 1;"),
+  // which is not valid protobuf; restored here.
+  map<string, TensorRepresentation> tensor_representation = 1;
+}