diff --git a/go/embedded/online_features.go b/go/embedded/online_features.go
index d019bb72ac..7cd1e4ed81 100644
--- a/go/embedded/online_features.go
+++ b/go/embedded/online_features.go
@@ -8,23 +8,22 @@ import (
 	"os"
 	"os/signal"
 	"syscall"
-
-	"google.golang.org/grpc"
-
-	"github.com/feast-dev/feast/go/internal/feast/server"
-	"github.com/feast-dev/feast/go/internal/feast/server/logging"
-	"github.com/feast-dev/feast/go/protos/feast/serving"
+	"time"
 
 	"github.com/apache/arrow/go/v8/arrow"
 	"github.com/apache/arrow/go/v8/arrow/array"
 	"github.com/apache/arrow/go/v8/arrow/cdata"
 	"github.com/apache/arrow/go/v8/arrow/memory"
+	"google.golang.org/grpc"
 
 	"github.com/feast-dev/feast/go/internal/feast"
 	"github.com/feast-dev/feast/go/internal/feast/model"
 	"github.com/feast-dev/feast/go/internal/feast/onlineserving"
 	"github.com/feast-dev/feast/go/internal/feast/registry"
+	"github.com/feast-dev/feast/go/internal/feast/server"
+	"github.com/feast-dev/feast/go/internal/feast/server/logging"
 	"github.com/feast-dev/feast/go/internal/feast/transformation"
+	"github.com/feast-dev/feast/go/protos/feast/serving"
 	prototypes "github.com/feast-dev/feast/go/protos/feast/types"
 	"github.com/feast-dev/feast/go/types"
 )
@@ -44,6 +43,15 @@ type DataTable struct {
 	SchemaPtr uintptr
 }
 
+// LoggingOptions is a public (embedded) copy of logging.LoggingOptions struct.
+// See logging.LoggingOptions for properties description
+type LoggingOptions struct {
+	ChannelCapacity int
+	EmitTimeout     time.Duration
+	WriteInterval   time.Duration
+	FlushInterval   time.Duration
+}
+
 func NewOnlineFeatureService(conf *OnlineFeatureServiceConfig, transformationCallback transformation.TransformationCallback) *OnlineFeatureService {
 	repoConfig, err := registry.NewRepoConfigFromJSON(conf.RepoPath, conf.RepoConfig)
 	if err != nil {
@@ -214,17 +222,50 @@ func (s *OnlineFeatureService) GetOnlineFeatures(
 	return nil
 }
 
+// StartGprcServer starts gRPC server with disabled feature logging and blocks the thread
 func (s *OnlineFeatureService) StartGprcServer(host string, port int) error {
-	// TODO(oleksii): enable logging
-	// Disable logging for now
+	return s.StartGprcServerWithLogging(host, port, nil, LoggingOptions{})
+}
+
+// StartGprcServerWithLoggingDefaultOpts starts gRPC server with enabled feature logging but default configuration for logging
+// Caller of this function must provide Python callback to flush buffered logs
+func (s *OnlineFeatureService) StartGprcServerWithLoggingDefaultOpts(host string, port int, writeLoggedFeaturesCallback logging.OfflineStoreWriteCallback) error {
+	defaultOpts := LoggingOptions{
+		ChannelCapacity: logging.DefaultOptions.ChannelCapacity,
+		EmitTimeout:     logging.DefaultOptions.EmitTimeout,
+		WriteInterval:   logging.DefaultOptions.WriteInterval,
+		FlushInterval:   logging.DefaultOptions.FlushInterval,
+	}
+	return s.StartGprcServerWithLogging(host, port, writeLoggedFeaturesCallback, defaultOpts)
+}
+
+// StartGprcServerWithLogging starts gRPC server with enabled feature logging
+// Caller of this function must provide Python callback to flush buffered logs as well as logging configuration (loggingOpts)
+func (s *OnlineFeatureService) StartGprcServerWithLogging(host string, port int, writeLoggedFeaturesCallback logging.OfflineStoreWriteCallback, loggingOpts LoggingOptions) error {
 	var loggingService *logging.LoggingService = nil
+	var err error
+	if writeLoggedFeaturesCallback != nil {
+		sink, err := logging.NewOfflineStoreSink(writeLoggedFeaturesCallback)
+		if err != nil {
+			return err
+		}
+
+		loggingService, err = logging.NewLoggingService(s.fs, sink, logging.LoggingOptions{
+			ChannelCapacity: loggingOpts.ChannelCapacity,
+			EmitTimeout:     loggingOpts.EmitTimeout,
+			WriteInterval:   loggingOpts.WriteInterval,
+			FlushInterval:   loggingOpts.FlushInterval,
+		})
+		if err != nil {
+			return err
+		}
+	}
 	ser := server.NewGrpcServingServiceServer(s.fs, loggingService)
 	log.Printf("Starting a gRPC server on host %s port %d\n", host, port)
 	lis, err := net.Listen("tcp", fmt.Sprintf("%s:%d", host, port))
 	if err != nil {
 		return err
 	}
-	log.Printf("Listening a gRPC server on host %s port %d\n", host, port)
 
 	grpcServer := grpc.NewServer()
 	serving.RegisterServingServiceServer(grpcServer, ser)
@@ -234,6 +275,10 @@ func (s *OnlineFeatureService) StartGprcServer(host string, port int) error {
 		<-s.grpcStopCh
 		fmt.Println("Stopping the gRPC server...")
 		grpcServer.GracefulStop()
+		if loggingService != nil {
+			loggingService.Stop()
+		}
+		fmt.Println("gRPC server terminated")
 	}()
 
 	err = grpcServer.Serve(lis)
diff --git a/go/internal/feast/server/logging/filelogsink.go b/go/internal/feast/server/logging/filelogsink.go
index 1d9afcd523..c9f2049a04 100644
--- a/go/internal/feast/server/logging/filelogsink.go
+++ b/go/internal/feast/server/logging/filelogsink.go
@@ -49,7 +49,7 @@ func (s *FileLogSink) Write(record arrow.Record) error {
 	return pqarrow.WriteTable(table, writer, 100, props, arrProps)
 }
 
-func (s *FileLogSink) Flush() error {
+func (s *FileLogSink) Flush(featureServiceName string) error {
 	// files are already flushed during Write
 	return nil
 }
diff --git a/go/internal/feast/server/logging/logger.go b/go/internal/feast/server/logging/logger.go
index 346bfdbf61..d7ed1fbe18 100644
--- a/go/internal/feast/server/logging/logger.go
+++ b/go/internal/feast/server/logging/logger.go
@@ -4,6 +4,7 @@ import (
 	"fmt"
 	"log"
 	"math/rand"
+	"strings"
 	"sync"
 	"time"
 
@@ -37,7 +38,7 @@ type LogSink interface {
 	// Flush actually send data to a sink.
 	// We want to control amount to interaction with sink, since it could be a costly operation.
 	// Also, some sinks like BigQuery might have quotes and physically limit amount of write requests per day.
-	Flush() error
+	Flush(featureServiceName string) error
 }
 
 type Logger interface {
@@ -135,6 +136,10 @@ func (l *LoggerImpl) loggerLoop() (lErr error) {
 			lErr = errors.WithStack(rErr)
 		}
 	}()
+
+	writeTicker := time.NewTicker(l.config.WriteInterval)
+	flushTicker := time.NewTicker(l.config.FlushInterval)
+
 	for {
 		shouldStop := false
@@ -144,18 +149,18 @@ func (l *LoggerImpl) loggerLoop() (lErr error) {
 			if err != nil {
 				log.Printf("Log write failed: %+v", err)
 			}
-			err = l.sink.Flush()
+			err = l.sink.Flush(l.featureServiceName)
 			if err != nil {
 				log.Printf("Log flush failed: %+v", err)
 			}
 			shouldStop = true
-		case <-time.After(l.config.WriteInterval):
+		case <-writeTicker.C:
 			err := l.buffer.writeBatch(l.sink)
 			if err != nil {
 				log.Printf("Log write failed: %+v", err)
 			}
-		case <-time.After(l.config.FlushInterval):
-			err := l.sink.Flush()
+		case <-flushTicker.C:
+			err := l.sink.Flush(l.featureServiceName)
 			if err != nil {
 				log.Printf("Log flush failed: %+v", err)
 			}
@@ -171,6 +176,9 @@ func (l *LoggerImpl) loggerLoop() (lErr error) {
 		}
 	}
 
+	writeTicker.Stop()
+	flushTicker.Stop()
+
 	// Notify all waiters for graceful stop
 	l.cond.L.Lock()
 	l.isStopped = true
@@ -225,7 +233,11 @@ func (l *LoggerImpl) Log(joinKeyToEntityValues map[string][]*types.Value, featur
 	for idx, featureName := range l.schema.Features {
 		featureIdx, ok := featureNameToVectorIdx[featureName]
 		if !ok {
-			return errors.Errorf("Missing feature %s in log data", featureName)
+			featureNameParts := strings.Split(featureName, "__")
+			featureIdx, ok = featureNameToVectorIdx[featureNameParts[1]]
+			if !ok {
+				return errors.Errorf("Missing feature %s in log data", featureName)
+			}
 		}
 		featureValues[idx] = featureVectors[featureIdx].Values[rowIdx]
 		featureStatuses[idx] = featureVectors[featureIdx].Statuses[rowIdx]
@@ -259,7 +271,7 @@ func (l *LoggerImpl) Log(joinKeyToEntityValues map[string][]*types.Value, featur
 
 		EventTimestamps: eventTimestamps,
 		RequestId:       requestId,
-		LogTimestamp:    time.Now(),
+		LogTimestamp:    time.Now().UTC(),
 	}
 	err := l.EmitLog(&newLog)
 	if err != nil {
diff --git a/go/internal/feast/server/logging/logger_test.go b/go/internal/feast/server/logging/logger_test.go
index 0c8e33ef6f..5625b05a76 100644
--- a/go/internal/feast/server/logging/logger_test.go
+++ b/go/internal/feast/server/logging/logger_test.go
@@ -28,7 +28,7 @@ func (s *DummySink) Write(rec arrow.Record) error {
 	return nil
 }
 
-func (s *DummySink) Flush() error {
+func (s *DummySink) Flush(featureServiceName string) error {
 	return nil
 }
diff --git a/go/internal/feast/server/logging/memorybuffer.go b/go/internal/feast/server/logging/memorybuffer.go
index 80a5e03228..36eb7118cb 100644
--- a/go/internal/feast/server/logging/memorybuffer.go
+++ b/go/internal/feast/server/logging/memorybuffer.go
@@ -2,7 +2,6 @@ package logging
 
 import (
 	"fmt"
-	"time"
 
 	"github.com/apache/arrow/go/v8/arrow"
 	"github.com/apache/arrow/go/v8/arrow/array"
@@ -143,7 +142,7 @@ func (b *MemoryBuffer) convertToArrowRecord() (arrow.Record, error) {
 		}
 
 		logTimestamp := arrow.Timestamp(logRow.LogTimestamp.UnixMicro())
-		logDate := arrow.Date32(logRow.LogTimestamp.Truncate(24 * time.Hour).Unix())
+		logDate := arrow.Date32FromTime(logRow.LogTimestamp)
 
 		builder.Field(fieldNameToIdx[LOG_TIMESTAMP_FIELD]).(*array.TimestampBuilder).UnsafeAppend(logTimestamp)
 		builder.Field(fieldNameToIdx[LOG_DATE_FIELD]).(*array.Date32Builder).UnsafeAppend(logDate)
diff --git a/go/internal/feast/server/logging/memorybuffer_test.go b/go/internal/feast/server/logging/memorybuffer_test.go
index f652f2c99a..59f035799b 100644
--- a/go/internal/feast/server/logging/memorybuffer_test.go
+++ b/go/internal/feast/server/logging/memorybuffer_test.go
@@ -158,7 +158,7 @@ func TestSerializeToArrowTable(t *testing.T) {
 	// log date
 	today := time.Now().Truncate(24 * time.Hour)
 	builder.Field(8).(*array.Date32Builder).AppendValues(
-		[]arrow.Date32{arrow.Date32(today.Unix()), arrow.Date32(today.Unix())}, []bool{true, true})
+		[]arrow.Date32{arrow.Date32FromTime(today), arrow.Date32FromTime(today)}, []bool{true, true})
 
 	// request id
 	builder.Field(9).(*array.StringBuilder).AppendValues(
diff --git a/go/internal/feast/server/logging/offlinestoresink.go b/go/internal/feast/server/logging/offlinestoresink.go
new file mode 100644
index 0000000000..ee4c646a9b
--- /dev/null
+++ b/go/internal/feast/server/logging/offlinestoresink.go
@@ -0,0 +1,83 @@
+package logging
+
+import (
+	"errors"
+	"fmt"
+	"io"
+	"io/ioutil"
+	"os"
+	"path/filepath"
+
+	"github.com/apache/arrow/go/v8/arrow"
+	"github.com/apache/arrow/go/v8/arrow/array"
+	"github.com/apache/arrow/go/v8/parquet"
+	"github.com/apache/arrow/go/v8/parquet/pqarrow"
+	"github.com/google/uuid"
+)
+
+type OfflineStoreWriteCallback func(featureServiceName, datasetDir string) string
+
+type OfflineStoreSink struct {
+	datasetDir    string
+	writeCallback OfflineStoreWriteCallback
+}
+
+func NewOfflineStoreSink(writeCallback OfflineStoreWriteCallback) (*OfflineStoreSink, error) {
+	return &OfflineStoreSink{
+		datasetDir:    "",
+		writeCallback: writeCallback,
+	}, nil
+}
+
+func (s *OfflineStoreSink) getOrCreateDatasetDir() (string, error) {
+	if s.datasetDir != "" {
+		return s.datasetDir, nil
+	}
+	dir, err := ioutil.TempDir("", "*")
+	if err != nil {
+		return "", err
+	}
+	s.datasetDir = dir
+	return s.datasetDir, nil
+}
+
+func (s *OfflineStoreSink) cleanCurrentDatasetDir() error {
+	if s.datasetDir == "" {
+		return nil
+	}
+	datasetDir := s.datasetDir
+	s.datasetDir = ""
+	return os.RemoveAll(datasetDir)
+}
+
+func (s *OfflineStoreSink) Write(record arrow.Record) error {
+	fileName, _ := uuid.NewUUID()
+	datasetDir, err := s.getOrCreateDatasetDir()
+	if err != nil {
+		return err
+	}
+
+	var writer io.Writer
+	writer, err = os.Create(filepath.Join(datasetDir, fmt.Sprintf("%s.parquet", fileName.String())))
+	if err != nil {
+		return err
+	}
+	table := array.NewTableFromRecords(record.Schema(), []arrow.Record{record})
+
+	props := parquet.NewWriterProperties(parquet.WithDictionaryDefault(false))
+	arrProps := pqarrow.DefaultWriterProps()
+	return pqarrow.WriteTable(table, writer, 1000, props, arrProps)
+}
+
+func (s *OfflineStoreSink) Flush(featureServiceName string) error {
+	if s.datasetDir == "" {
+		return nil
+	}
+
+	errMsg := s.writeCallback(featureServiceName, s.datasetDir)
+	if errMsg != "" {
+		return errors.New(errMsg)
+	}
+
+	return s.cleanCurrentDatasetDir()
+}
diff --git a/go/internal/feast/server/logging/service.go b/go/internal/feast/server/logging/service.go
index a06698638a..9249ad4f2f 100644
--- a/go/internal/feast/server/logging/service.go
+++ b/go/internal/feast/server/logging/service.go
@@ -98,3 +98,10 @@ func (s *LoggingService) GetOrCreateLogger(featureService *model.FeatureService)
 
 	return logger, nil
 }
+
+func (s *LoggingService) Stop() {
+	for _, logger := range s.loggers {
+		logger.Stop()
+		logger.WaitUntilStopped()
+	}
+}
diff --git a/go/internal/test/go_integration_test_utils.go b/go/internal/test/go_integration_test_utils.go
index 6d236a4319..eb727ba1db 100644
--- a/go/internal/test/go_integration_test_utils.go
+++ b/go/internal/test/go_integration_test_utils.go
@@
-138,7 +138,7 @@ func SetupInitializedRepo(basePath string) error { // var stderr bytes.Buffer // var stdout bytes.Buffer applyCommand.Dir = featureRepoPath - out, err := applyCommand.Output() + out, err := applyCommand.CombinedOutput() if err != nil { log.Println(string(out)) return err @@ -152,7 +152,7 @@ func SetupInitializedRepo(basePath string) error { materializeCommand := exec.Command("feast", "materialize-incremental", formattedTime) materializeCommand.Env = os.Environ() materializeCommand.Dir = featureRepoPath - out, err = materializeCommand.Output() + out, err = materializeCommand.CombinedOutput() if err != nil { log.Println(string(out)) return err diff --git a/sdk/python/feast/embedded_go/online_features_service.py b/sdk/python/feast/embedded_go/online_features_service.py index 613838b9a9..48e31766cb 100644 --- a/sdk/python/feast/embedded_go/online_features_service.py +++ b/sdk/python/feast/embedded_go/online_features_service.py @@ -1,4 +1,5 @@ from functools import partial +from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import pyarrow as pa @@ -17,7 +18,12 @@ from feast.types import from_value_type from feast.value_type import ValueType -from .lib.embedded import DataTable, NewOnlineFeatureService, OnlineFeatureServiceConfig +from .lib.embedded import ( + DataTable, + LoggingOptions, + NewOnlineFeatureService, + OnlineFeatureServiceConfig, +) from .lib.go import Slice_string from .type_map import FEAST_TYPE_TO_ARROW_TYPE, arrow_array_to_array_of_proto @@ -31,6 +37,7 @@ def __init__( ): # keep callback in self to prevent it from GC self._transformation_callback = partial(transformation_callback, feature_store) + self._logging_callback = partial(logging_callback, feature_store) self._service = NewOnlineFeatureService( OnlineFeatureServiceConfig( @@ -132,8 +139,24 @@ def get_online_features( resp = record_batch_to_online_response(record_batch) return OnlineResponse(resp) - def start_grpc_server(self, host: str, port: int): - self._service.StartGprcServer(host, port) + def start_grpc_server( + self, + host: str, + port: int, + enable_logging: bool = True, + logging_options: Optional[LoggingOptions] = None, + ): + if enable_logging: + if logging_options: + self._service.StartGprcServerWithLogging( + host, port, self._logging_callback, logging_options + ) + else: + self._service.StartGprcServerWithLoggingDefaultOpts( + host, port, self._logging_callback + ) + else: + self._service.StartGprcServer(host, port) def stop_grpc_server(self): self._service.Stop() @@ -182,6 +205,18 @@ def transformation_callback( return output_record.num_rows +def logging_callback( + fs: "FeatureStore", feature_service_name: str, dataset_dir: str, +) -> bytes: + feature_service = fs.get_feature_service(feature_service_name, allow_cache=True) + try: + fs.write_logged_features(logs=Path(dataset_dir), source=feature_service) + except Exception as exc: + return repr(exc).encode() + + return "".encode() # no error + + def allocate_schema_and_array(): c_schema = ffi.new("struct ArrowSchema*") ptr_schema = int(ffi.cast("uintptr_t", c_schema)) diff --git a/sdk/python/feast/feature_logging.py b/sdk/python/feast/feature_logging.py index acc965ac44..70fab930bb 100644 --- a/sdk/python/feast/feature_logging.py +++ b/sdk/python/feast/feature_logging.py @@ -108,12 +108,12 @@ class _DestinationRegistry(type): def __new__(cls, name, bases, dct): kls = type.__new__(cls, name, bases, dct) - if dct.get("_proto_attr_name"): - 
cls.classes_by_proto_attr_name[dct["_proto_attr_name"]] = kls + if dct.get("_proto_kind"): + cls.classes_by_proto_attr_name[dct["_proto_kind"]] = kls return kls -class LoggingDestination: +class LoggingDestination(metaclass=_DestinationRegistry): """ Logging destination contains details about where exactly logs should be written inside an offline store. It is implementation specific - each offline store must implement LoggingDestination subclass. diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 13c73612f0..4b015e8ab8 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -1988,10 +1988,16 @@ def serve_transformations(self, port: int) -> None: def _teardown_go_server(self): self._go_server = None - def write_logged_features(self, logs: pa.Table, source: Union[FeatureService]): + def write_logged_features( + self, logs: Union[pa.Table, Path], source: Union[FeatureService] + ): """ Write logs produced by a source (currently only feature service is supported as a source) to an offline store. + + Args: + logs: Arrow Table or path to parquet dataset directory on disk + source: Object that produces logs """ if not isinstance(source, FeatureService): raise ValueError("Only feature service is currently supported as a source") @@ -2000,6 +2006,8 @@ def write_logged_features(self, logs: pa.Table, source: Union[FeatureService]): source.logging_config is not None ), "Feature service must be configured with logging config in order to use this functionality" + assert isinstance(logs, (pa.Table, Path)) + self._get_provider().write_feature_service_logs( feature_service=source, logs=logs, diff --git a/sdk/python/feast/infra/offline_stores/bigquery.py b/sdk/python/feast/infra/offline_stores/bigquery.py index ee024d4d40..e9d8bdccbf 100644 --- a/sdk/python/feast/infra/offline_stores/bigquery.py +++ b/sdk/python/feast/infra/offline_stores/bigquery.py @@ -2,6 +2,7 @@ import tempfile import uuid from datetime import date, datetime, timedelta +from pathlib import Path from typing import ( Callable, ContextManager, @@ -258,7 +259,7 @@ def query_generator() -> Iterator[str]: @staticmethod def write_logged_features( config: RepoConfig, - data: pyarrow.Table, + data: Union[pyarrow.Table, Path], source: LoggingSource, logging_config: LoggingConfig, registry: Registry, @@ -280,6 +281,17 @@ def write_logged_features( ), ) + if isinstance(data, Path): + for file in data.iterdir(): + with file.open("rb") as f: + client.load_table_from_file( + file_obj=f, + destination=destination.table, + job_config=job_config, + ) + + return + with tempfile.TemporaryFile() as parquet_temp_file: pyarrow.parquet.write_table(table=data, where=parquet_temp_file) diff --git a/sdk/python/feast/infra/offline_stores/file.py b/sdk/python/feast/infra/offline_stores/file.py index a85cd880b1..2dea5714fa 100644 --- a/sdk/python/feast/infra/offline_stores/file.py +++ b/sdk/python/feast/infra/offline_stores/file.py @@ -1,4 +1,5 @@ from datetime import datetime +from pathlib import Path from typing import Callable, List, Optional, Tuple, Union import dask.dataframe as dd @@ -375,7 +376,7 @@ def pull_all_from_table_or_query( @staticmethod def write_logged_features( config: RepoConfig, - data: pyarrow.Table, + data: Union[pyarrow.Table, Path], source: LoggingSource, logging_config: LoggingConfig, registry: Registry, @@ -383,6 +384,9 @@ def write_logged_features( destination = logging_config.destination assert isinstance(destination, FileLoggingDestination) + if isinstance(data, 
Path): + data = pyarrow.parquet.read_table(data) + filesystem, path = FileSource.create_filesystem_and_path( destination.path, destination.s3_endpoint_override, ) diff --git a/sdk/python/feast/infra/offline_stores/offline_store.py b/sdk/python/feast/infra/offline_stores/offline_store.py index ed545ed5ad..27a85046e5 100644 --- a/sdk/python/feast/infra/offline_stores/offline_store.py +++ b/sdk/python/feast/infra/offline_stores/offline_store.py @@ -14,6 +14,7 @@ import warnings from abc import ABC, abstractmethod from datetime import datetime +from pathlib import Path from typing import TYPE_CHECKING, List, Optional, Union import pandas as pd @@ -246,7 +247,7 @@ def pull_all_from_table_or_query( @staticmethod def write_logged_features( config: RepoConfig, - data: pyarrow.Table, + data: Union[pyarrow.Table, Path], source: LoggingSource, logging_config: LoggingConfig, registry: Registry, @@ -259,7 +260,7 @@ def write_logged_features( Args: config: Repo configuration object - data: Arrow table produced by logging source. + data: Arrow table or path to parquet directory that contains logs dataset. source: Logging source that provides schema and some additional metadata. logging_config: used to determine destination registry: Feast registry diff --git a/sdk/python/feast/infra/offline_stores/redshift.py b/sdk/python/feast/infra/offline_stores/redshift.py index fb5759cbbb..74ba83cb00 100644 --- a/sdk/python/feast/infra/offline_stores/redshift.py +++ b/sdk/python/feast/infra/offline_stores/redshift.py @@ -1,6 +1,7 @@ import contextlib import uuid from datetime import datetime +from pathlib import Path from typing import ( Callable, ContextManager, @@ -265,7 +266,7 @@ def query_generator() -> Iterator[str]: @staticmethod def write_logged_features( config: RepoConfig, - data: pyarrow.Table, + data: Union[pyarrow.Table, Path], source: LoggingSource, logging_config: LoggingConfig, registry: Registry, @@ -277,7 +278,10 @@ def write_logged_features( config.offline_store.region ) s3_resource = aws_utils.get_s3_resource(config.offline_store.region) - s3_path = f"{config.offline_store.s3_staging_location}/logged_features/{uuid.uuid4()}.parquet" + if isinstance(data, Path): + s3_path = f"{config.offline_store.s3_staging_location}/logged_features/{uuid.uuid4()}" + else: + s3_path = f"{config.offline_store.s3_staging_location}/logged_features/{uuid.uuid4()}.parquet" aws_utils.upload_arrow_table_to_redshift( table=data, diff --git a/sdk/python/feast/infra/offline_stores/snowflake.py b/sdk/python/feast/infra/offline_stores/snowflake.py index 4cf6716c5e..d39acc9f08 100644 --- a/sdk/python/feast/infra/offline_stores/snowflake.py +++ b/sdk/python/feast/infra/offline_stores/snowflake.py @@ -42,6 +42,7 @@ execute_snowflake_statement, get_snowflake_conn, write_pandas, + write_parquet, ) from feast.registry import Registry from feast.repo_config import FeastConfigBaseModel, RepoConfig @@ -280,7 +281,7 @@ def query_generator() -> Iterator[str]: @staticmethod def write_logged_features( config: RepoConfig, - data: pyarrow.Table, + data: Union[pyarrow.Table, Path], source: LoggingSource, logging_config: LoggingConfig, registry: Registry, @@ -289,12 +290,21 @@ def write_logged_features( snowflake_conn = get_snowflake_conn(config.offline_store) - write_pandas( - snowflake_conn, - data.to_pandas(), - table_name=logging_config.destination.table_name, - auto_create_table=True, - ) + if isinstance(data, Path): + write_parquet( + snowflake_conn, + data, + source.get_schema(registry), + 
table_name=logging_config.destination.table_name, + auto_create_table=True, + ) + else: + write_pandas( + snowflake_conn, + data.to_pandas(), + table_name=logging_config.destination.table_name, + auto_create_table=True, + ) class SnowflakeRetrievalJob(RetrievalJob): diff --git a/sdk/python/feast/infra/passthrough_provider.py b/sdk/python/feast/infra/passthrough_provider.py index 6364297b1e..0b6b798fe0 100644 --- a/sdk/python/feast/infra/passthrough_provider.py +++ b/sdk/python/feast/infra/passthrough_provider.py @@ -221,7 +221,7 @@ def retrieve_saved_dataset( def write_feature_service_logs( self, feature_service: FeatureService, - logs: pyarrow.Table, + logs: Union[pyarrow.Table, str], config: RepoConfig, registry: Registry, ): diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py index f8c2a4482f..7754a58319 100644 --- a/sdk/python/feast/infra/provider.py +++ b/sdk/python/feast/infra/provider.py @@ -190,7 +190,7 @@ def retrieve_saved_dataset( def write_feature_service_logs( self, feature_service: FeatureService, - logs: pyarrow.Table, + logs: Union[pyarrow.Table, Path], config: RepoConfig, registry: Registry, ): @@ -199,6 +199,8 @@ def write_feature_service_logs( Schema of logs table is being inferred from the provided feature service. Only feature services with configured logging are accepted. + + Logs dataset can be passed as Arrow Table or path to parquet directory. """ ... @@ -206,8 +208,8 @@ def write_feature_service_logs( def retrieve_feature_service_logs( self, feature_service: FeatureService, - from_: datetime, - to: datetime, + start_date: datetime, + end_date: datetime, config: RepoConfig, registry: Registry, ) -> RetrievalJob: diff --git a/sdk/python/feast/infra/utils/aws_utils.py b/sdk/python/feast/infra/utils/aws_utils.py index d73c484b29..bb75160a87 100644 --- a/sdk/python/feast/infra/utils/aws_utils.py +++ b/sdk/python/feast/infra/utils/aws_utils.py @@ -2,7 +2,8 @@ import os import tempfile import uuid -from typing import Any, Dict, Iterator, Optional, Tuple +from pathlib import Path +from typing import Any, Dict, Iterator, Optional, Tuple, Union import pandas as pd import pyarrow @@ -235,7 +236,7 @@ def upload_df_to_redshift( def upload_arrow_table_to_redshift( - table: pyarrow.Table, + table: Union[pyarrow.Table, Path], redshift_data_client, cluster_id: str, database: str, @@ -265,7 +266,7 @@ def upload_arrow_table_to_redshift( iam_role: IAM Role for Redshift to assume during the COPY command. The role must grant permission to read the S3 location. 
table_name: The name of the new Redshift table where we copy the dataframe - table: The Arrow Table to upload + table: The Arrow Table or Path to parquet dataset to upload schema: (Optionally) client may provide arrow Schema which will be converted into redshift table schema fail_if_exists: fail if table with such name exists or append data to existing table @@ -275,18 +276,35 @@ def upload_arrow_table_to_redshift( if len(table_name) > REDSHIFT_TABLE_NAME_MAX_LENGTH: raise RedshiftTableNameTooLong(table_name) + if isinstance(table, pyarrow.Table) and not schema: + schema = table.schema + + if not schema: + raise ValueError("Schema must be specified when data is passed as a Path") + bucket, key = get_bucket_and_key(s3_path) - schema = schema or table.schema column_query_list = ", ".join( [f"{field.name} {pa_to_redshift_value_type(field.type)}" for field in schema] ) - # Write the PyArrow Table on disk in Parquet format and upload it to S3 - with tempfile.TemporaryFile(suffix=".parquet") as parquet_temp_file: - pq.write_table(table, parquet_temp_file) - parquet_temp_file.seek(0) - s3_resource.Object(bucket, key).put(Body=parquet_temp_file) + uploaded_files = [] + + if isinstance(table, Path): + for file in table.iterdir(): + file_key = os.path.join(key, file.name) + with file.open("rb") as f: + s3_resource.Object(bucket, file_key).put(Body=f) + + uploaded_files.append(file_key) + else: + # Write the PyArrow Table on disk in Parquet format and upload it to S3 + with tempfile.TemporaryFile(suffix=".parquet") as parquet_temp_file: + pq.write_table(table, parquet_temp_file) + parquet_temp_file.seek(0) + s3_resource.Object(bucket, key).put(Body=parquet_temp_file) + + uploaded_files.append(key) copy_query = ( f"COPY {table_name} FROM '{s3_path}' IAM_ROLE '{iam_role}' FORMAT AS PARQUET" @@ -306,7 +324,8 @@ def upload_arrow_table_to_redshift( ) finally: # Clean up S3 temporary data - s3_resource.Object(bucket, key).delete() + for file_pah in uploaded_files: + s3_resource.Object(bucket, file_pah).delete() @contextlib.contextmanager diff --git a/sdk/python/feast/infra/utils/snowflake_utils.py b/sdk/python/feast/infra/utils/snowflake_utils.py index a467a9de42..05834ae436 100644 --- a/sdk/python/feast/infra/utils/snowflake_utils.py +++ b/sdk/python/feast/infra/utils/snowflake_utils.py @@ -3,10 +3,12 @@ import random import string from logging import getLogger +from pathlib import Path from tempfile import TemporaryDirectory from typing import Any, Dict, Iterator, List, Optional, Tuple, cast import pandas as pd +import pyarrow from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization from tenacity import ( @@ -138,6 +140,73 @@ def write_pandas( the passed in DataFrame. 
The table will not be created if it already exists create_temp_table: Will make the auto-created table as a temporary table """ + + cursor: SnowflakeCursor = conn.cursor() + stage_name = create_temporary_sfc_stage(cursor) + + upload_df(df, cursor, stage_name, chunk_size, parallel, compression) + copy_uploaded_data_to_table( + cursor, + stage_name, + list(df.columns), + table_name, + database, + schema, + compression, + on_error, + quote_identifiers, + auto_create_table, + create_temp_table, + ) + + +def write_parquet( + conn: SnowflakeConnection, + path: Path, + dataset_schema: pyarrow.Schema, + table_name: str, + database: Optional[str] = None, + schema: Optional[str] = None, + compression: str = "gzip", + on_error: str = "abort_statement", + parallel: int = 4, + quote_identifiers: bool = True, + auto_create_table: bool = False, + create_temp_table: bool = False, +): + cursor: SnowflakeCursor = conn.cursor() + stage_name = create_temporary_sfc_stage(cursor) + + columns = [field.name for field in dataset_schema] + upload_local_pq(path, cursor, stage_name, parallel) + copy_uploaded_data_to_table( + cursor, + stage_name, + columns, + table_name, + database, + schema, + compression, + on_error, + quote_identifiers, + auto_create_table, + create_temp_table, + ) + + +def copy_uploaded_data_to_table( + cursor: SnowflakeCursor, + stage_name: str, + columns: List[str], + table_name: str, + database: Optional[str] = None, + schema: Optional[str] = None, + compression: str = "gzip", + on_error: str = "abort_statement", + quote_identifiers: bool = True, + auto_create_table: bool = False, + create_temp_table: bool = False, +): if database is not None and schema is None: raise ProgrammingError( "Schema has to be provided to write_pandas when a database is provided" @@ -163,37 +232,11 @@ def write_pandas( + (schema + "." 
if schema else "") + (table_name) ) - if chunk_size is None: - chunk_size = len(df) - cursor: SnowflakeCursor = conn.cursor() - stage_name = create_temporary_sfc_stage(cursor) - with TemporaryDirectory() as tmp_folder: - for i, chunk in chunk_helper(df, chunk_size): - chunk_path = os.path.join(tmp_folder, "file{}.txt".format(i)) - # Dump chunk into parquet file - chunk.to_parquet( - chunk_path, - compression=compression, - use_deprecated_int96_timestamps=True, - ) - # Upload parquet file - upload_sql = ( - "PUT /* Python:snowflake.connector.pandas_tools.write_pandas() */ " - "'file://{path}' @\"{stage_name}\" PARALLEL={parallel}" - ).format( - path=chunk_path.replace("\\", "\\\\").replace("'", "\\'"), - stage_name=stage_name, - parallel=parallel, - ) - logger.debug(f"uploading files with '{upload_sql}'") - cursor.execute(upload_sql, _is_internal=True) - # Remove chunk file - os.remove(chunk_path) if quote_identifiers: - columns = '"' + '","'.join(list(df.columns)) + '"' + quoted_columns = '"' + '","'.join(columns) + '"' else: - columns = ",".join(list(df.columns)) + quoted_columns = ",".join(columns) if auto_create_table: file_format_name = create_file_format(compression, compression_map, cursor) @@ -209,7 +252,7 @@ def write_pandas( # columns in order quote = '"' if quote_identifiers else "" create_table_columns = ", ".join( - [f"{quote}{c}{quote} {column_type_mapping[c]}" for c in df.columns] + [f"{quote}{c}{quote} {column_type_mapping[c]}" for c in columns] ) create_table_sql = ( f"CREATE {'TEMP ' if create_temp_table else ''}TABLE IF NOT EXISTS {location} " @@ -225,9 +268,9 @@ def write_pandas( # in Snowflake, all parquet data is stored in a single column, $1, so we must select columns explicitly # see (https://docs.snowflake.com/en/user-guide/script-data-load-transform-parquet.html) if quote_identifiers: - parquet_columns = "$1:" + ",$1:".join(f'"{c}"' for c in df.columns) + parquet_columns = "$1:" + ",$1:".join(f'"{c}"' for c in columns) else: - parquet_columns = "$1:" + ",$1:".join(df.columns) + parquet_columns = "$1:" + ",$1:".join(columns) copy_into_sql = ( "COPY INTO {location} /* Python:snowflake.connector.pandas_tools.write_pandas() */ " "({columns}) " @@ -236,7 +279,7 @@ def write_pandas( "PURGE=TRUE ON_ERROR={on_error}" ).format( location=location, - columns=columns, + columns=quoted_columns, parquet_columns=parquet_columns, stage_name=stage_name, compression=compression_map[compression], @@ -250,6 +293,78 @@ def write_pandas( result_cursor.close() +def upload_df( + df: pd.DataFrame, + cursor: SnowflakeCursor, + stage_name: str, + chunk_size: Optional[int] = None, + parallel: int = 4, + compression: str = "gzip", +): + """ + Args: + df: Dataframe we'd like to write back. + cursor: cursor to be used to communicate with Snowflake. + stage_name: stage name in Snowflake connection. + chunk_size: Number of elements to be inserted once, if not provided all elements will be dumped once + (Default value = None). + parallel: Number of threads to be used when uploading chunks, default follows documentation at: + https://docs.snowflake.com/en/sql-reference/sql/put.html#optional-parameters (Default value = 4). + compression: The compression used on the Parquet files, can only be gzip, or snappy. Gzip gives supposedly a + better compression, while snappy is faster. Use whichever is more appropriate (Default value = 'gzip'). 
+ + """ + if chunk_size is None: + chunk_size = len(df) + + with TemporaryDirectory() as tmp_folder: + for i, chunk in chunk_helper(df, chunk_size): + chunk_path = os.path.join(tmp_folder, "file{}.txt".format(i)) + # Dump chunk into parquet file + chunk.to_parquet( + chunk_path, + compression=compression, + use_deprecated_int96_timestamps=True, + ) + # Upload parquet file + upload_sql = ( + "PUT /* Python:feast.infra.utils.snowflake_utils.upload_df() */ " + "'file://{path}' @\"{stage_name}\" PARALLEL={parallel}" + ).format( + path=chunk_path.replace("\\", "\\\\").replace("'", "\\'"), + stage_name=stage_name, + parallel=parallel, + ) + logger.debug(f"uploading files with '{upload_sql}'") + cursor.execute(upload_sql, _is_internal=True) + # Remove chunk file + os.remove(chunk_path) + + +def upload_local_pq( + path: Path, cursor: SnowflakeCursor, stage_name: str, parallel: int = 4, +): + """ + Args: + path: Path to parquet dataset on disk + cursor: cursor to be used to communicate with Snowflake. + stage_name: stage name in Snowflake connection. + parallel: Number of threads to be used when uploading chunks, default follows documentation at: + https://docs.snowflake.com/en/sql-reference/sql/put.html#optional-parameters (Default value = 4). + """ + for file in path.iterdir(): + upload_sql = ( + "PUT /* Python:feast.infra.utils.snowflake_utils.upload_local_pq() */ " + "'file://{path}' @\"{stage_name}\" PARALLEL={parallel}" + ).format( + path=str(file).replace("\\", "\\\\").replace("'", "\\'"), + stage_name=stage_name, + parallel=parallel, + ) + logger.debug(f"uploading files with '{upload_sql}'") + cursor.execute(upload_sql, _is_internal=True) + + @retry( wait=wait_exponential(multiplier=1, max=4), retry=retry_if_exception_type(ProgrammingError), diff --git a/sdk/python/tests/foo_provider.py b/sdk/python/tests/foo_provider.py index bccf7931b5..2d61c36273 100644 --- a/sdk/python/tests/foo_provider.py +++ b/sdk/python/tests/foo_provider.py @@ -1,4 +1,5 @@ from datetime import datetime +from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import pandas @@ -84,7 +85,7 @@ def retrieve_saved_dataset(self, config: RepoConfig, dataset: SavedDataset): def write_feature_service_logs( self, feature_service: FeatureService, - logs: pyarrow.Table, + logs: Union[pyarrow.Table, Path], config: RepoConfig, registry: Registry, ): @@ -93,8 +94,8 @@ def write_feature_service_logs( def retrieve_feature_service_logs( self, feature_service: FeatureService, - from_: datetime, - to: datetime, + start_date: datetime, + end_date: datetime, config: RepoConfig, registry: Registry, ) -> RetrievalJob: diff --git a/sdk/python/tests/integration/e2e/test_go_feature_server.py b/sdk/python/tests/integration/e2e/test_go_feature_server.py index 325e185ce2..a45a101f48 100644 --- a/sdk/python/tests/integration/e2e/test_go_feature_server.py +++ b/sdk/python/tests/integration/e2e/test_go_feature_server.py @@ -1,14 +1,20 @@ import socket import threading +import time from contextlib import closing +from datetime import datetime from typing import List import grpc +import pandas as pd import pytest +import pytz from feast import FeatureService, ValueType +from feast.embedded_go.lib.embedded import LoggingOptions from feast.embedded_go.online_features_service import EmbeddedOnlineFeatureServer from feast.feast_object import FeastObject +from feast.feature_logging import LoggingConfig from feast.protos.feast.serving.ServingService_pb2 import ( FieldStatus, GetOnlineFeaturesRequest, @@ 
-38,15 +44,22 @@ LOCAL_REPO_CONFIGS = [ IntegrationTestRepoConfig(online_store=REDIS_CONFIG, go_feature_retrieval=True), ] +LOCAL_REPO_CONFIGS = [ + c + for c in LOCAL_REPO_CONFIGS + if c.offline_store_creator in AVAILABLE_OFFLINE_STORES + and c.online_store in AVAILABLE_ONLINE_STORES +] + +NANOSECOND = 1 +MILLISECOND = 1000_000 * NANOSECOND +SECOND = 1000 * MILLISECOND @pytest.fixture( - params=[ - c - for c in LOCAL_REPO_CONFIGS - if c.offline_store_creator in AVAILABLE_OFFLINE_STORES - and c.online_store in AVAILABLE_ONLINE_STORES - ] + params=LOCAL_REPO_CONFIGS, + ids=[str(c) for c in LOCAL_REPO_CONFIGS], + scope="session", ) def local_environment(request): e = construct_test_environment(request.param) @@ -58,15 +71,25 @@ def cleanup(): return e -@pytest.fixture -def initialized_registry(local_environment): +@pytest.fixture(scope="session") +def test_data(local_environment): + return construct_universal_test_data(local_environment) + + +@pytest.fixture(scope="session") +def initialized_registry(local_environment, test_data): fs = local_environment.feature_store - entities, datasets, data_sources = construct_universal_test_data(local_environment) + _, _, data_sources = test_data feature_views = construct_universal_feature_views(data_sources) feature_service = FeatureService( - name="driver_features", features=[feature_views.driver] + name="driver_features", + features=[feature_views.driver], + logging_config=LoggingConfig( + destination=local_environment.data_source_creator.create_logged_features_destination(), + sample_rate=1.0, + ), ) feast_objects: List[FeastObject] = [feature_service] feast_objects.extend(feature_views.values()) @@ -85,7 +108,19 @@ def grpc_server_port(local_environment, initialized_registry): ) port = free_port() - t = threading.Thread(target=embedded.start_grpc_server, args=("127.0.0.1", port)) + t = threading.Thread( + target=embedded.start_grpc_server, + args=("127.0.0.1", port), + kwargs=dict( + enable_logging=True, + logging_options=LoggingOptions( + ChannelCapacity=100, + WriteInterval=100 * MILLISECOND, + FlushInterval=1 * SECOND, + EmitTimeout=10 * MILLISECOND, + ), + ), + ) t.start() wait_retry_backoff( @@ -94,6 +129,8 @@ def grpc_server_port(local_environment, initialized_registry): yield port embedded.stop_grpc_server() + # wait for graceful stop + time.sleep(2) @pytest.fixture @@ -128,6 +165,69 @@ def test_go_grpc_server(grpc_client): assert all([s == FieldStatus.PRESENT for s in vector.statuses]) +@pytest.mark.integration +@pytest.mark.universal +@pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: str(v)) +def test_feature_logging(grpc_client, local_environment, test_data, full_feature_names): + fs = local_environment.feature_store + feature_service = fs.get_feature_service("driver_features") + log_start_date = datetime.now().astimezone(pytz.UTC) + driver_ids = list(range(5001, 5011)) + + for driver_id in driver_ids: + # send each driver id in separate request + grpc_client.GetOnlineFeatures( + GetOnlineFeaturesRequest( + feature_service="driver_features", + entities={ + "driver_id": RepeatedValue( + val=python_values_to_proto_values( + [driver_id], feature_type=ValueType.INT64 + ) + ) + }, + full_feature_names=full_feature_names, + ) + ) + # with some pause + time.sleep(0.1) + + _, datasets, _ = test_data + latest_rows = get_latest_rows(datasets.driver_df, "driver_id", driver_ids) + features = [ + feature.name + for proj in feature_service.feature_view_projections + for feature in proj.features + ] + expected_logs = 
generate_expected_logs( + latest_rows, "driver_stats", features, ["driver_id"], "event_timestamp" + ) + + def retrieve(): + retrieval_job = fs._get_provider().retrieve_feature_service_logs( + feature_service=feature_service, + start_date=log_start_date, + end_date=datetime.now().astimezone(pytz.UTC), + config=fs.config, + registry=fs._registry, + ) + try: + df = retrieval_job.to_df() + except Exception: + # Table or directory was not created yet + return None, False + + return df, df.shape[0] == len(driver_ids) + + persisted_logs = wait_retry_backoff( + retrieve, timeout_secs=60, timeout_msg="Logs retrieval failed" + ) + + persisted_logs = persisted_logs.sort_values(by="driver_id").reset_index(drop=True) + persisted_logs = persisted_logs[expected_logs.columns] + pd.testing.assert_frame_equal(expected_logs, persisted_logs, check_dtype=False) + + def free_port(): sock = socket.socket() sock.bind(("", 0)) @@ -137,3 +237,23 @@ def free_port(): def check_port_open(host, port) -> bool: with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: return sock.connect_ex((host, port)) == 0 + + +def get_latest_rows(df, join_key, entity_values): + rows = df[df[join_key].isin(entity_values)] + return rows.loc[rows.groupby(join_key)["event_timestamp"].idxmax()] + + +def generate_expected_logs( + df, feature_view_name, features, join_keys, timestamp_column +): + logs = pd.DataFrame() + for join_key in join_keys: + logs[join_key] = df[join_key] + + for feature in features: + logs[f"{feature_view_name}__{feature}"] = df[feature] + logs[f"{feature_view_name}__{feature}__timestamp"] = df[timestamp_column] + logs[f"{feature_view_name}__{feature}__status"] = FieldStatus.PRESENT + + return logs.sort_values(by=join_keys).reset_index(drop=True) diff --git a/sdk/python/tests/integration/offline_store/test_feature_logging.py b/sdk/python/tests/integration/offline_store/test_feature_logging.py index f15eb8a849..6dda2e63a9 100644 --- a/sdk/python/tests/integration/offline_store/test_feature_logging.py +++ b/sdk/python/tests/integration/offline_store/test_feature_logging.py @@ -1,8 +1,13 @@ +import contextlib import datetime +import tempfile import uuid +from pathlib import Path +from typing import Iterator, Union import numpy as np import pandas as pd +import pyarrow import pyarrow as pa import pytest from google.api_core.exceptions import NotFound @@ -27,7 +32,8 @@ @pytest.mark.integration @pytest.mark.universal -def test_feature_service_logging(environment, universal_data_sources): +@pytest.mark.parametrize("pass_as_path", [True, False], ids=lambda v: str(v)) +def test_feature_service_logging(environment, universal_data_sources, pass_as_path): store = environment.feature_store (_, datasets, data_sources) = universal_data_sources @@ -50,21 +56,23 @@ def test_feature_service_logging(environment, universal_data_sources): ), ) - num_rows = logs_df.shape[0] - first_batch = logs_df.iloc[: num_rows // 2, :] - second_batch = logs_df.iloc[num_rows // 2 :, :] - schema = FeatureServiceLoggingSource( feature_service=feature_service, project=store.project ).get_schema(store._registry) - store.write_logged_features( - source=feature_service, logs=pa.Table.from_pandas(first_batch, schema=schema), - ) + num_rows = logs_df.shape[0] + first_batch = pa.Table.from_pandas(logs_df.iloc[: num_rows // 2, :], schema=schema) + second_batch = pa.Table.from_pandas(logs_df.iloc[num_rows // 2 :, :], schema=schema) - store.write_logged_features( - source=feature_service, logs=pa.Table.from_pandas(second_batch, schema=schema), - ) + 
+    with to_logs_dataset(first_batch, pass_as_path) as logs:
+        store.write_logged_features(
+            source=feature_service, logs=logs,
+        )
+
+    with to_logs_dataset(second_batch, pass_as_path) as logs:
+        store.write_logged_features(
+            source=feature_service, logs=logs,
+        )
 
     expected_columns = list(set(logs_df.columns) - {LOG_DATE_FIELD})
 
     def retrieve():
@@ -122,3 +130,16 @@ def prepare_logs(datasets: UniversalDatasets) -> pd.DataFrame:
         logs_df[f"{view}__{feature}__status"] = FieldStatus.PRESENT
 
     return logs_df
+
+
+@contextlib.contextmanager
+def to_logs_dataset(
+    table: pyarrow.Table, pass_as_path: bool
+) -> Iterator[Union[pyarrow.Table, Path]]:
+    if not pass_as_path:
+        yield table
+        return
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        pyarrow.parquet.write_to_dataset(table, root_path=temp_dir)
+        yield Path(temp_dir)
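
A few usage sketches for the APIs this change introduces. First, the Go-side entry points (StartGprcServer, StartGprcServerWithLoggingDefaultOpts, StartGprcServerWithLogging) are driven from Python through the embedded bindings. A minimal sketch of enabling feature logging when starting the embedded gRPC server; the repo path, host, and port are placeholders, and the EmbeddedOnlineFeatureServer constructor arguments are assumed from the SDK wrapper rather than shown in this diff:

```python
from feast import FeatureStore
from feast.embedded_go.lib.embedded import LoggingOptions
from feast.embedded_go.online_features_service import EmbeddedOnlineFeatureServer

# LoggingOptions durations map to Go time.Duration values, i.e. integer nanoseconds.
NANOSECOND = 1
MILLISECOND = 1_000_000 * NANOSECOND
SECOND = 1000 * MILLISECOND

store = FeatureStore(repo_path=".")  # placeholder repo
embedded = EmbeddedOnlineFeatureServer(
    repo_path=str(store.repo_path.absolute()),
    repo_config=store.config,
    feature_store=store,
)

# enable_logging=True wires the Python logging callback into the Go offline-store sink;
# passing logging_options selects StartGprcServerWithLogging, while omitting it falls
# back to StartGprcServerWithLoggingDefaultOpts. The call blocks until
# stop_grpc_server() is invoked from another thread.
embedded.start_grpc_server(
    "127.0.0.1",
    6566,
    enable_logging=True,
    logging_options=LoggingOptions(
        ChannelCapacity=1000,
        EmitTimeout=10 * MILLISECOND,
        WriteInterval=200 * MILLISECOND,
        FlushInterval=5 * SECOND,
    ),
)
```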
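
On flush, OfflineStoreSink hands control back to Python through OfflineStoreWriteCallback(featureServiceName, datasetDir), treating an empty result as success. The real implementation is the logging_callback added to online_features_service.py; the hypothetical standalone callback below only illustrates the same contract (the print is illustrative, and reading the directory back is optional):

```python
from pathlib import Path

import pyarrow.parquet as pq


def example_logging_callback(feature_service_name: str, dataset_dir: str) -> bytes:
    # The Go sink writes one or more parquet files into dataset_dir and invokes this
    # callback on every flush; empty bytes mean "no error", anything else is surfaced
    # as the flush error message in the Go server logs.
    try:
        table = pq.read_table(dataset_dir)
        print(f"{feature_service_name}: received {table.num_rows} logged rows")
        # The real callback forwards the directory to
        # FeatureStore.write_logged_features(logs=Path(dataset_dir), source=feature_service).
    except Exception as exc:
        return repr(exc).encode()
    return b""
```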
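
Feature logging only engages for feature services that carry a LoggingConfig, which is what the new integration test sets up. A sketch of the registration side plus reading logs back through the provider; the feature view name, destination path, and the import location of FileLoggingDestination are assumptions to verify against your Feast version:

```python
from datetime import datetime, timedelta

import pytz

from feast import FeatureService, FeatureStore
from feast.feature_logging import LoggingConfig
from feast.infra.offline_stores.file_source import FileLoggingDestination  # assumed import path

store = FeatureStore(repo_path=".")  # placeholder

driver_service = FeatureService(
    name="driver_features",
    features=[store.get_feature_view("driver_hourly_stats")],  # placeholder view
    logging_config=LoggingConfig(
        destination=FileLoggingDestination(path="data/logged_features"),
        sample_rate=1.0,
    ),
)
store.apply([driver_service])

# Later: pull back whatever the Go server logged for this service.
logs_df = (
    store._get_provider()
    .retrieve_feature_service_logs(
        feature_service=driver_service,
        start_date=datetime.now().astimezone(pytz.UTC) - timedelta(hours=1),
        end_date=datetime.now().astimezone(pytz.UTC),
        config=store.config,
        registry=store._registry,
    )
    .to_df()
)
```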
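
write_logged_features now accepts either an in-memory Arrow table or a Path pointing at a directory of parquet files, which is exactly what the Go OfflineStoreSink produces. A sketch of both call shapes, assuming the feature service is registered with a logging_config; the tiny table below is illustrative only, since real logs must match the schema from FeatureServiceLoggingSource.get_schema():

```python
from pathlib import Path

import pyarrow as pa
import pyarrow.parquet as pq

from feast import FeatureStore

store = FeatureStore(repo_path=".")  # placeholder
feature_service = store.get_feature_service("driver_features", allow_cache=True)

# Illustrative stand-in; a real table must follow the logging schema.
logs_table = pa.table({"driver_id": pa.array([1001, 1002], type=pa.int64())})

# 1) Pre-existing path: pass the Arrow table directly.
store.write_logged_features(logs=logs_table, source=feature_service)

# 2) New path: point at a directory of parquet files on disk.
dataset_dir = Path("/tmp/feast_logged_features")  # placeholder location
dataset_dir.mkdir(parents=True, exist_ok=True)
pq.write_to_dataset(logs_table, root_path=str(dataset_dir))
store.write_logged_features(logs=dataset_dir, source=feature_service)
```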
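
For Snowflake, the Path branch goes through the new write_parquet helper, which stages every file in the directory with PUT and reuses the shared COPY INTO logic via copy_uploaded_data_to_table. A standalone sketch; the connection config and table name are placeholders, and the signature follows the one introduced in snowflake_utils above:

```python
from pathlib import Path

import pyarrow.parquet as pq

from feast import FeatureStore
from feast.infra.utils.snowflake_utils import get_snowflake_conn, write_parquet

store = FeatureStore(repo_path=".")  # placeholder; assumes a Snowflake offline store config
dataset_dir = Path("/tmp/feast_logged_features")  # directory of *.parquet files

conn = get_snowflake_conn(store.config.offline_store)
dataset_schema = pq.read_schema(next(dataset_dir.glob("*.parquet")))  # any file's schema

write_parquet(
    conn,
    dataset_dir,
    dataset_schema,
    table_name="DRIVER_FEATURES_LOGS",  # placeholder destination table
    auto_create_table=True,
)
```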
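
The Redshift helper likewise accepts a Path: each parquet file is uploaded under an S3 prefix, COPYed, and then cleaned up, and the Arrow schema must be passed explicitly because there is no in-memory table to infer it from. A heavily hedged sketch; every identifier below is a placeholder and the keyword names should be checked against aws_utils in your version:

```python
from pathlib import Path

import pyarrow as pa

from feast.infra.utils import aws_utils

region = "us-west-2"  # placeholder
dataset_dir = Path("/tmp/feast_logged_features")
dataset_schema = pa.schema([("driver_id", pa.int64())])  # stand-in for the real logging schema

aws_utils.upload_arrow_table_to_redshift(
    table=dataset_dir,  # a Path triggers the multi-file upload branch
    redshift_data_client=aws_utils.get_redshift_data_client(region),
    cluster_id="feast-cluster",
    database="feast",
    user="admin",
    s3_resource=aws_utils.get_s3_resource(region),
    s3_path="s3://feast-staging/logged_features/run-1",  # a prefix, not a single .parquet key
    iam_role="arn:aws:iam::123456789012:role/feast-redshift-s3",
    table_name="driver_features_logs",
    schema=dataset_schema,  # mandatory when table is a Path
    fail_if_exists=False,
)
```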