diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala
index c2ebd388a2365..c97b10ee63b18 100644
--- a/core/src/main/scala/org/apache/spark/TestUtils.scala
+++ b/core/src/main/scala/org/apache/spark/TestUtils.scala
@@ -192,6 +192,20 @@ private[spark] object TestUtils {
     assert(listener.numSpilledStages == 0, s"expected $identifier to not spill, but did")
   }
 
+  /**
+   * Asserts that the exception message, or the message of any exception in its cause chain,
+   * contains the expected message.
+   */
+  def assertExceptionMsg(exception: Throwable, msg: String): Unit = {
+    var e = exception
+    var contains = e.getMessage.contains(msg)
+    while (e.getCause != null && !contains) {
+      e = e.getCause
+      contains = e.getMessage.contains(msg)
+    }
+    assert(contains, s"Exception tree doesn't contain the expected message: $msg")
+  }
+
   /**
    * Test if a command is available.
    */
diff --git a/docs/_data/menu-sql.yaml b/docs/_data/menu-sql.yaml
index cd065ea01dda4..9bbb115bcdda5 100644
--- a/docs/_data/menu-sql.yaml
+++ b/docs/_data/menu-sql.yaml
@@ -70,6 +70,8 @@
       url: sql-migration-guide-upgrade.html
     - text: Compatibility with Apache Hive
       url: sql-migration-guide-hive-compatibility.html
+    - text: SQL Reserved/Non-Reserved Keywords
+      url: sql-reserved-and-non-reserved-keywords.html
 - text: Reference
   url: sql-reference.html
   subitems:
diff --git a/docs/sql-reserved-and-non-reserved-keywords.md b/docs/sql-reserved-and-non-reserved-keywords.md
new file mode 100644
index 0000000000000..53eb9988f6c88
--- /dev/null
+++ b/docs/sql-reserved-and-non-reserved-keywords.md
@@ -0,0 +1,575 @@
+---
+layout: global
+title: SQL Reserved/Non-Reserved Keywords
+displayTitle: SQL Reserved/Non-Reserved Keywords
+---
+
+In Spark SQL, there are two kinds of keywords: non-reserved and reserved. Non-reserved keywords have a
+special meaning only in particular contexts and can be used as identifiers (e.g., table names, view names,
+column names, column aliases, table aliases) in other contexts. Reserved keywords can't be used as a
+table alias, but can be used as other kinds of identifiers.
+
+The list of reserved and non-reserved keywords can change according to the config
+`spark.sql.parser.ansi.enabled`, which is false by default.
+
+|Keyword|Spark SQL (ANSI mode)|Spark SQL (default mode)|SQL-2011|
+|---|---|---|---|
+|ABS|non-reserved|non-reserved|reserved|
+|ABSOLUTE|non-reserved|non-reserved|non-reserved|
+|ACOS|non-reserved|non-reserved|non-reserved|
+|ACTION|non-reserved|non-reserved|non-reserved|
+|ADD|non-reserved|non-reserved|non-reserved|
+|AFTER|non-reserved|non-reserved|non-reserved|
+|ALL|reserved|non-reserved|reserved|
+|ALLOCATE|non-reserved|non-reserved|reserved|
+|ALTER|non-reserved|non-reserved|reserved|
+|ANALYZE|non-reserved|non-reserved|non-reserved|
+|AND|reserved|non-reserved|reserved|
+|ANTI|reserved|reserved|non-reserved|
+|ANY|reserved|non-reserved|reserved|
+|ARE|non-reserved|non-reserved|reserved|
+|ARCHIVE|non-reserved|non-reserved|non-reserved|
+|ARRAY|non-reserved|non-reserved|reserved|
+|ARRAY_AGG|non-reserved|non-reserved|reserved|
+|ARRAY_MAX_CARDINALITY|non-reserved|non-reserved|reserved|
+|AS|reserved|non-reserved|reserved|
+|ASC|non-reserved|non-reserved|non-reserved|
+|ASENSITIVE|non-reserved|non-reserved|reserved|
+|ASIN|non-reserved|non-reserved|reserved|
+|ASSERTION|non-reserved|non-reserved|non-reserved|
+|ASYMMETRIC|non-reserved|non-reserved|reserved|
+|AT|non-reserved|non-reserved|reserved|
+|ATAN|non-reserved|non-reserved|non-reserved|
+|ATOMIC|non-reserved|non-reserved|reserved|
+|AUTHORIZATION|reserved|non-reserved|reserved|
+|AVG|non-reserved|non-reserved|reserved|
+|BEFORE|non-reserved|non-reserved|non-reserved|
+|BEGIN|non-reserved|non-reserved|reserved|
+|BEGIN_FRAME|non-reserved|non-reserved|reserved|
+|BEGIN_PARTITION|non-reserved|non-reserved|reserved|
+|BETWEEN|non-reserved|non-reserved|reserved|
+|BIGINT|non-reserved|non-reserved|reserved|
+|BINARY|non-reserved|non-reserved|reserved|
+|BIT|non-reserved|non-reserved|non-reserved|
+|BIT_LENGTH|non-reserved|non-reserved|non-reserved|
+|BLOB|non-reserved|non-reserved|reserved|
+|BOOLEAN|non-reserved|non-reserved|reserved|
+|BOTH|reserved|non-reserved|reserved|
+|BREADTH|non-reserved|non-reserved|non-reserved|
+|BUCKET|non-reserved|non-reserved|non-reserved|
+|BUCKETS|non-reserved|non-reserved|non-reserved|
+|BY|non-reserved|non-reserved|reserved|
+|CACHE|non-reserved|non-reserved|non-reserved|
+|CALL|non-reserved|non-reserved|reserved|
+|CALLED|non-reserved|non-reserved|reserved|
+|CARDINALITY|non-reserved|non-reserved|reserved|
+|CASCADE|non-reserved|non-reserved|reserved|
+|CASCADED|non-reserved|non-reserved|reserved|
+|CASE|reserved|non-reserved|reserved|
+|CAST|reserved|non-reserved|reserved|
+|CATALOG|non-reserved|non-reserved|non-reserved|
+|CEIL|non-reserved|non-reserved|reserved|
+|CEILING|non-reserved|non-reserved|reserved|
+|CHANGE|non-reserved|non-reserved|non-reserved|
+|CHAR|non-reserved|non-reserved|reserved|
+|CHAR_LENGTH|non-reserved|non-reserved|reserved|
+|CHARACTER|non-reserved|non-reserved|reserved|
+|CHARACTER_LENGTH|non-reserved|non-reserved|reserved|
+|CHECK|reserved|non-reserved|reserved|
+|CLASSIFIER|non-reserved|non-reserved|non-reserved|
+|CLEAR|non-reserved|non-reserved|non-reserved|
+|CLOB|non-reserved|non-reserved|reserved|
+|CLOSE|non-reserved|non-reserved|reserved|
+|CLUSTER|non-reserved|non-reserved|non-reserved|
+|CLUSTERED|non-reserved|non-reserved|non-reserved|
+|COALESCE|non-reserved|non-reserved|reserved|
+|CODEGEN|non-reserved|non-reserved|non-reserved|
+|COLLATE|reserved|non-reserved|reserved|
+|COLLATION|non-reserved|non-reserved|non-reserved|
+|COLLECT|non-reserved|non-reserved|reserved|
+|COLLECTION|non-reserved|non-reserved|non-reserved|
+|COLUMN|reserved|non-reserved|reserved|
+|COLUMNS|non-reserved|non-reserved|non-reserved|
+|COMMENT|non-reserved|non-reserved|non-reserved|
+|COMMIT|non-reserved|non-reserved|reserved|
+|COMPACT|non-reserved|non-reserved|non-reserved|
+|COMPACTIONS|non-reserved|non-reserved|non-reserved|
+|COMPUTE|non-reserved|non-reserved|non-reserved|
+|CONCATENATE|non-reserved|non-reserved|non-reserved|
+|CONDITION|non-reserved|non-reserved|reserved|
+|CONNECT|non-reserved|non-reserved|non-reserved|
+|CONNECTION|non-reserved|non-reserved|non-reserved|
+|CONSTRAINT|reserved|non-reserved|reserved|
+|CONSTRAINTS|non-reserved|non-reserved|non-reserved|
+|CONSTRUCTOR|non-reserved|non-reserved|non-reserved|
+|CONTAINS|non-reserved|non-reserved|non-reserved|
+|CONTINUE|non-reserved|non-reserved|non-reserved|
+|CONVERT|non-reserved|non-reserved|reserved|
+|COPY|non-reserved|non-reserved|non-reserved|
+|CORR|non-reserved|non-reserved|reserved|
+|CORRESPONDING|non-reserved|non-reserved|reserved|
+|COS|non-reserved|non-reserved|non-reserved|
+|COSH|non-reserved|non-reserved|non-reserved|
+|COST|non-reserved|non-reserved|non-reserved|
+|COUNT|non-reserved|non-reserved|reserved|
+|COVAR_POP|non-reserved|non-reserved|reserved|
+|COVAR_SAMP|non-reserved|non-reserved|reserved|
+|CREATE|reserved|non-reserved|reserved|
+|CROSS|reserved|reserved|reserved|
+|CUBE|non-reserved|non-reserved|reserved|
+|CUME_DIST|non-reserved|non-reserved|reserved|
+|CURRENT|non-reserved|non-reserved|reserved|
+|CURRENT_CATALOG|non-reserved|non-reserved|reserved|
+|CURRENT_DATE|reserved|non-reserved|reserved|
+|CURRENT_DEFAULT_TRANSFORM_GROUP|non-reserved|non-reserved|reserved|
+|CURRENT_PATH|non-reserved|non-reserved|reserved|
+|CURRENT_ROLE|non-reserved|non-reserved|reserved|
+|CURRENT_ROW|non-reserved|non-reserved|reserved|
+|CURRENT_SCHEMA|non-reserved|non-reserved|reserved|
+|CURRENT_TIME|reserved|non-reserved|reserved|
+|CURRENT_TIMESTAMP|reserved|non-reserved|reserved|
+|CURRENT_TRANSFORM_GROUP_FOR_TYPE|non-reserved|non-reserved|reserved|
+|CURRENT_USER|reserved|non-reserved|reserved|
+|CURSOR|non-reserved|non-reserved|reserved|
+|CYCLE|non-reserved|non-reserved|reserved|
+|DATA|non-reserved|non-reserved|non-reserved|
+|DATABASE|non-reserved|non-reserved|non-reserved|
+|DATABASES|non-reserved|non-reserved|non-reserved|
+|DATE|non-reserved|non-reserved|reserved|
+|DAY|non-reserved|non-reserved|reserved|
+|DBPROPERTIES|non-reserved|non-reserved|non-reserved|
+|DEALLOCATE|non-reserved|non-reserved|reserved|
+|DEC|non-reserved|non-reserved|reserved|
+|DECFLOAT|non-reserved|non-reserved|non-reserved|
+|DECIMAL|non-reserved|non-reserved|reserved|
+|DECLARE|non-reserved|non-reserved|reserved|
+|DEFAULT|non-reserved|non-reserved|reserved|
+|DEFERRABLE|non-reserved|non-reserved|non-reserved|
+|DEFERRED|non-reserved|non-reserved|non-reserved|
+|DEFINE|non-reserved|non-reserved|non-reserved|
+|DEFINED|non-reserved|non-reserved|non-reserved|
+|DELETE|non-reserved|non-reserved|reserved|
+|DELIMITED|non-reserved|non-reserved|non-reserved|
+|DENSE_RANK|non-reserved|non-reserved|reserved|
+|DEPTH|non-reserved|non-reserved|non-reserved|
+|DEREF|non-reserved|non-reserved|reserved|
+|DESC|non-reserved|non-reserved|non-reserved|
+|DESCRIBE|non-reserved|non-reserved|reserved|
+|QUERY|non-reserved|non-reserved|non-reserved|
+|DESCRIPTOR|non-reserved|non-reserved|non-reserved|
+|DETERMINISTIC|non-reserved|non-reserved|reserved|
+|DFS|non-reserved|non-reserved|non-reserved|
+|DIAGNOSTICS|non-reserved|non-reserved|non-reserved|
+|DIRECTORIES|non-reserved|non-reserved|non-reserved|
+|DIRECTORY|non-reserved|non-reserved|non-reserved|
+|DISCONNECT|non-reserved|non-reserved|reserved|
+|DISTINCT|reserved|non-reserved|reserved|
+|DISTRIBUTE|non-reserved|non-reserved|non-reserved|
+|DIV|non-reserved|non-reserved|non-reserved|
+|DO|non-reserved|non-reserved|reserved|
+|DOMAIN|non-reserved|non-reserved|non-reserved|
+|DOUBLE|non-reserved|non-reserved|reserved|
+|DROP|non-reserved|non-reserved|reserved|
+|DYNAMIC|non-reserved|non-reserved|reserved|
+|EACH|non-reserved|non-reserved|reserved|
+|ELEMENT|non-reserved|non-reserved|reserved|
+|ELSE|reserved|non-reserved|reserved|
+|ELSEIF|non-reserved|non-reserved|reserved|
+|EMPTY|non-reserved|non-reserved|non-reserved|
+|END|reserved|non-reserved|reserved|
+|END_FRAME|non-reserved|non-reserved|reserved|
+|END_PARTITION|non-reserved|non-reserved|reserved|
+|EQUALS|non-reserved|non-reserved|non-reserved|
+|ESCAPE|non-reserved|non-reserved|reserved|
+|ESCAPED|non-reserved|non-reserved|non-reserved|
+|EVERY|non-reserved|non-reserved|reserved|
+|EXCEPT|reserved|reserved|reserved|
+|EXCEPTION|non-reserved|non-reserved|non-reserved|
+|EXCHANGE|non-reserved|non-reserved|non-reserved|
+|EXEC|non-reserved|non-reserved|reserved|
+|EXECUTE|non-reserved|non-reserved|reserved|
+|EXISTS|non-reserved|non-reserved|reserved|
+|EXIT|non-reserved|non-reserved|non-reserved|
+|EXP|non-reserved|non-reserved|non-reserved|
+|EXPLAIN|non-reserved|non-reserved|non-reserved|
+|EXPORT|non-reserved|non-reserved|non-reserved|
+|EXTENDED|non-reserved|non-reserved|non-reserved|
+|EXTERNAL|non-reserved|non-reserved|reserved|
+|EXTRACT|non-reserved|non-reserved|reserved|
+|FALSE|reserved|non-reserved|reserved|
+|FETCH|reserved|non-reserved|reserved|
+|FIELDS|non-reserved|non-reserved|non-reserved|
+|FILEFORMAT|non-reserved|non-reserved|non-reserved|
+|FILTER|non-reserved|non-reserved|reserved|
+|FIRST|non-reserved|non-reserved|non-reserved|
+|FIRST_VALUE|non-reserved|non-reserved|reserved|
+|FLOAT|non-reserved|non-reserved|reserved|
+|FOLLOWING|non-reserved|non-reserved|non-reserved|
+|FOR|reserved|non-reserved|reserved|
+|FOREIGN|reserved|non-reserved|reserved|
+|FORMAT|non-reserved|non-reserved|non-reserved|
+|FORMATTED|non-reserved|non-reserved|non-reserved|
+|FOUND|non-reserved|non-reserved|non-reserved|
+|FRAME_ROW|non-reserved|non-reserved|reserved|
+|FREE|non-reserved|non-reserved|reserved|
+|FROM|reserved|non-reserved|reserved|
+|FULL|reserved|reserved|reserved|
+|FUNCTION|non-reserved|non-reserved|reserved|
+|FUNCTIONS|non-reserved|non-reserved|non-reserved|
+|FUSION|non-reserved|non-reserved|non-reserved|
+|GENERAL|non-reserved|non-reserved|non-reserved|
+|GET|non-reserved|non-reserved|reserved|
+|GLOBAL|non-reserved|non-reserved|reserved|
+|GO|non-reserved|non-reserved|non-reserved|
+|GOTO|non-reserved|non-reserved|non-reserved|
+|GRANT|reserved|non-reserved|reserved|
+|GROUP|reserved|non-reserved|reserved|
+|GROUPING|non-reserved|non-reserved|reserved|
+|GROUPS|non-reserved|non-reserved|reserved|
+|HANDLER|non-reserved|non-reserved|reserved|
+|HAVING|reserved|non-reserved|reserved|
+|HOLD|non-reserved|non-reserved|reserved|
+|HOUR|non-reserved|non-reserved|reserved|
+|IDENTITY|non-reserved|non-reserved|reserved|
+|IF|non-reserved|non-reserved|reserved|
+|IGNORE|non-reserved|non-reserved|non-reserved|
+|IMMEDIATE|non-reserved|non-reserved|non-reserved|
+|IMPORT|non-reserved|non-reserved|non-reserved|
+|IN|reserved|non-reserved|reserved|
+|INDICATOR|non-reserved|non-reserved|reserved|
+|INDEX|non-reserved|non-reserved|non-reserved|
+|INDEXES|non-reserved|non-reserved|non-reserved|
+|INITIAL|non-reserved|non-reserved|non-reserved|
+|INITIALLY|non-reserved|non-reserved|non-reserved|
+|INNER|reserved|reserved|reserved|
+|INOUT|non-reserved|non-reserved|reserved|
+|INPATH|non-reserved|non-reserved|non-reserved|
+|INPUT|non-reserved|non-reserved|non-reserved|
+|INPUTFORMAT|non-reserved|non-reserved|non-reserved|
+|INSENSITIVE|non-reserved|non-reserved|reserved|
+|INSERT|non-reserved|non-reserved|reserved|
+|INT|non-reserved|non-reserved|reserved|
+|INTEGER|non-reserved|non-reserved|reserved|
+|INTERSECT|reserved|reserved|reserved|
+|INTERSECTION|non-reserved|non-reserved|reserved|
+|INTERVAL|non-reserved|non-reserved|reserved|
+|INTO|reserved|non-reserved|reserved|
+|IS|reserved|non-reserved|reserved|
+|ISOLATION|non-reserved|non-reserved|non-reserved|
+|ITEMS|non-reserved|non-reserved|non-reserved|
+|ITERATE|non-reserved|non-reserved|reserved|
+|JOIN|reserved|reserved|reserved|
+|JSON_ARRAY|non-reserved|non-reserved|non-reserved|
+|JSON_ARRAYAGG|non-reserved|non-reserved|non-reserved|
+|JSON_EXISTS|non-reserved|non-reserved|non-reserved|
+|JSON_OBJECT|non-reserved|non-reserved|non-reserved|
+|JSON_OBJECTAGG|non-reserved|non-reserved|non-reserved|
+|JSON_QUERY|non-reserved|non-reserved|non-reserved|
+|JSON_TABLE|non-reserved|non-reserved|non-reserved|
+|JSON_TABLE_PRIMITIVE|non-reserved|non-reserved|non-reserved|
+|JSON_VALUE|non-reserved|non-reserved|non-reserved|
+|KEY|non-reserved|non-reserved|non-reserved|
+|KEYS|non-reserved|non-reserved|non-reserved|
+|LAG|non-reserved|non-reserved|non-reserved|
+|LANGUAGE|non-reserved|non-reserved|reserved|
+|LARGE|non-reserved|non-reserved|reserved|
+|LAST|non-reserved|non-reserved|non-reserved|
+|LAST_VALUE|non-reserved|non-reserved|reserved|
+|LATERAL|non-reserved|non-reserved|reserved|
+|LAZY|non-reserved|non-reserved|non-reserved|
+|LEAD|non-reserved|non-reserved|reserved|
+|LEADING|reserved|non-reserved|reserved|
+|LEAVE|non-reserved|non-reserved|reserved|
+|LEFT|reserved|reserved|reserved|
+|LEVEL|non-reserved|non-reserved|non-reserved|
+|LIKE|non-reserved|non-reserved|reserved|
+|LIKE_REGEX|non-reserved|non-reserved|reserved|
+|LIMIT|non-reserved|non-reserved|non-reserved|
+|LINES|non-reserved|non-reserved|non-reserved|
+|LIST|non-reserved|non-reserved|non-reserved|
+|LISTAGG|non-reserved|non-reserved|non-reserved|
+|LN|non-reserved|non-reserved|reserved|
+|LOAD|non-reserved|non-reserved|non-reserved|
+|LOCAL|non-reserved|non-reserved|reserved|
+|LOCALTIME|non-reserved|non-reserved|reserved|
+|LOCALTIMESTAMP|non-reserved|non-reserved|reserved|
+|LOCATION|non-reserved|non-reserved|non-reserved|
+|LOCATOR|non-reserved|non-reserved|non-reserved|
+|LOCK|non-reserved|non-reserved|non-reserved|
+|LOCKS|non-reserved|non-reserved|non-reserved|
+|LOG|non-reserved|non-reserved|non-reserved|
+|LOG10|non-reserved|non-reserved|non-reserved|
+|LOGICAL|non-reserved|non-reserved|non-reserved|
+|LOOP|non-reserved|non-reserved|reserved|
+|LOWER|non-reserved|non-reserved|reserved|
+|MACRO|non-reserved|non-reserved|non-reserved|
+|MAP|non-reserved|non-reserved|non-reserved|
+|MATCH|non-reserved|non-reserved|reserved|
+|MATCH_NUMBER|non-reserved|non-reserved|non-reserved|
+|MATCH_RECOGNIZE|non-reserved|non-reserved|non-reserved|
+|MATCHES|non-reserved|non-reserved|non-reserved|
+|MAX|non-reserved|non-reserved|reserved|
+|MEMBER|non-reserved|non-reserved|reserved|
+|MERGE|non-reserved|non-reserved|reserved|
+|METHOD|non-reserved|non-reserved|reserved|
+|MIN|non-reserved|non-reserved|reserved|
+|MINUS|reserved|reserved|non-reserved|
+|MINUTE|non-reserved|non-reserved|reserved|
+|MOD|non-reserved|non-reserved|reserved|
+|MODIFIES|non-reserved|non-reserved|reserved|
+|MODULE|non-reserved|non-reserved|reserved|
+|MONTH|non-reserved|non-reserved|reserved|
+|MSCK|non-reserved|non-reserved|non-reserved|
+|MULTISET|non-reserved|non-reserved|reserved|
+|NAMES|non-reserved|non-reserved|non-reserved|
+|NATIONAL|non-reserved|non-reserved|reserved|
+|NATURAL|reserved|reserved|reserved|
+|NCHAR|non-reserved|non-reserved|reserved|
+|NCLOB|non-reserved|non-reserved|reserved|
+|NEW|non-reserved|non-reserved|reserved|
+|NEXT|non-reserved|non-reserved|non-reserved|
+|NO|non-reserved|non-reserved|reserved|
+|NONE|non-reserved|non-reserved|reserved|
+|NORMALIZE|non-reserved|non-reserved|reserved|
+|NOT|reserved|non-reserved|reserved|
+|NTH_VALUE|non-reserved|non-reserved|reserved|
+|NTILE|non-reserved|non-reserved|reserved|
+|NULL|reserved|non-reserved|reserved|
+|NULLS|non-reserved|non-reserved|non-reserved|
+|NULLIF|non-reserved|non-reserved|reserved|
+|NUMERIC|non-reserved|non-reserved|reserved|
+|OBJECT|non-reserved|non-reserved|non-reserved|
+|OCCURRENCES_REGEX|non-reserved|non-reserved|non-reserved|
+|OCTET_LENGTH|non-reserved|non-reserved|reserved|
+|OF|non-reserved|non-reserved|reserved|
+|OFFSET|non-reserved|non-reserved|reserved|
+|OLD|non-reserved|non-reserved|reserved|
+|OMIT|non-reserved|non-reserved|non-reserved|
+|ON|reserved|reserved|reserved|
+|ONE|non-reserved|non-reserved|non-reserved|
+|ONLY|reserved|non-reserved|reserved|
+|OPEN|non-reserved|non-reserved|reserved|
+|OPTION|non-reserved|non-reserved|non-reserved|
+|OPTIONS|non-reserved|non-reserved|non-reserved|
+|OR|reserved|non-reserved|reserved|
+|ORDER|reserved|non-reserved|reserved|
+|ORDINALITY|non-reserved|non-reserved|non-reserved|
+|OUT|non-reserved|non-reserved|reserved|
+|OUTER|reserved|non-reserved|reserved|
+|OUTPUT|non-reserved|non-reserved|non-reserved|
+|OUTPUTFORMAT|non-reserved|non-reserved|non-reserved|
+|OVER|non-reserved|non-reserved|non-reserved|
+|OVERLAPS|reserved|non-reserved|reserved|
+|OVERLAY|non-reserved|non-reserved|reserved|
+|OVERWRITE|non-reserved|non-reserved|non-reserved|
+|PAD|non-reserved|non-reserved|non-reserved|
+|PARAMETER|non-reserved|non-reserved|reserved|
+|PARTIAL|non-reserved|non-reserved|non-reserved|
+|PARTITION|non-reserved|non-reserved|reserved|
+|PARTITIONED|non-reserved|non-reserved|non-reserved|
+|PARTITIONS|non-reserved|non-reserved|non-reserved|
+|PATH|non-reserved|non-reserved|non-reserved|
+|PATTERN|non-reserved|non-reserved|non-reserved|
+|PER|non-reserved|non-reserved|non-reserved|
+|PERCENT|non-reserved|non-reserved|reserved|
+|PERCENT_RANK|non-reserved|non-reserved|reserved|
+|PERCENTILE_CONT|non-reserved|non-reserved|reserved|
+|PERCENTILE_DISC|non-reserved|non-reserved|reserved|
+|PERCENTLIT|non-reserved|non-reserved|non-reserved|
+|PERIOD|non-reserved|non-reserved|reserved|
+|PIVOT|non-reserved|non-reserved|non-reserved|
+|PORTION|non-reserved|non-reserved|reserved|
+|POSITION|non-reserved|non-reserved|reserved|
+|POSITION_REGEX|non-reserved|non-reserved|reserved|
+|POWER|non-reserved|non-reserved|reserved|
+|PRECEDES|non-reserved|non-reserved|reserved|
+|PRECEDING|non-reserved|non-reserved|non-reserved|
+|PRECISION|non-reserved|non-reserved|reserved|
+|PREPARE|non-reserved|non-reserved|reserved|
+|PRESERVE|non-reserved|non-reserved|non-reserved|
+|PRIMARY|reserved|non-reserved|reserved|
+|PRINCIPALS|non-reserved|non-reserved|non-reserved|
+|PRIOR|non-reserved|non-reserved|non-reserved|
+|PRIVILEGES|non-reserved|non-reserved|non-reserved|
+|PROCEDURE|non-reserved|non-reserved|reserved|
+|PTF|non-reserved|non-reserved|non-reserved|
+|PUBLIC|non-reserved|non-reserved|non-reserved|
+|PURGE|non-reserved|non-reserved|non-reserved|
+|RANGE|non-reserved|non-reserved|reserved|
+|RANK|non-reserved|non-reserved|reserved|
+|READ|non-reserved|non-reserved|non-reserved|
+|READS|non-reserved|non-reserved|reserved|
+|REAL|non-reserved|non-reserved|reserved|
+|RECORDREADER|non-reserved|non-reserved|non-reserved|
+|RECORDWRITER|non-reserved|non-reserved|non-reserved|
+|RECURSIVE|non-reserved|non-reserved|reserved|
+|RECOVER|non-reserved|non-reserved|non-reserved|
+|REDUCE|non-reserved|non-reserved|non-reserved|
+|REF|non-reserved|non-reserved|reserved|
+|REFERENCES|reserved|non-reserved|reserved|
+|REFERENCING|non-reserved|non-reserved|reserved|
+|REFRESH|non-reserved|non-reserved|non-reserved|
+|REGR_AVGX|non-reserved|non-reserved|reserved|
+|REGR_AVGY|non-reserved|non-reserved|reserved|
+|REGR_COUNT|non-reserved|non-reserved|reserved|
+|REGR_INTERCEPT|non-reserved|non-reserved|reserved|
+|REGR_R2|non-reserved|non-reserved|reserved|
+|REGR_SLOPE|non-reserved|non-reserved|reserved|
+|REGR_SXX|non-reserved|non-reserved|reserved|
+|REGR_SXY|non-reserved|non-reserved|reserved|
+|REGR_SYY|non-reserved|non-reserved|reserved|
+|RELATIVE|non-reserved|non-reserved|non-reserved|
+|RELEASE|non-reserved|non-reserved|reserved|
+|RENAME|non-reserved|non-reserved|non-reserved|
+|REPAIR|non-reserved|non-reserved|non-reserved|
+|REPEAT|non-reserved|non-reserved|reserved|
+|REPLACE|non-reserved|non-reserved|non-reserved|
+|RESET|non-reserved|non-reserved|non-reserved|
+|RESIGNAL|non-reserved|non-reserved|reserved|
+|RESTRICT|non-reserved|non-reserved|non-reserved|
+|RESULT|non-reserved|non-reserved|reserved|
+|RETURN|non-reserved|non-reserved|reserved|
+|RETURNS|non-reserved|non-reserved|reserved|
+|REVOKE|non-reserved|non-reserved|reserved|
+|RIGHT|reserved|reserved|reserved|
+|RLIKE|non-reserved|non-reserved|non-reserved|
+|ROLE|non-reserved|non-reserved|non-reserved|
+|ROLES|non-reserved|non-reserved|non-reserved|
+|ROLLBACK|non-reserved|non-reserved|reserved|
+|ROLLUP|non-reserved|non-reserved|reserved|
+|ROUTINE|non-reserved|non-reserved|non-reserved|
+|ROW|non-reserved|non-reserved|reserved|
+|ROW_NUMBER|non-reserved|non-reserved|reserved|
+|ROWS|non-reserved|non-reserved|reserved|
+|RUNNING|non-reserved|non-reserved|non-reserved|
+|SAVEPOINT|non-reserved|non-reserved|reserved|
+|SCHEMA|non-reserved|non-reserved|non-reserved|
+|SCOPE|non-reserved|non-reserved|reserved|
+|SCROLL|non-reserved|non-reserved|reserved|
+|SEARCH|non-reserved|non-reserved|reserved|
+|SECOND|non-reserved|non-reserved|reserved|
+|SECTION|non-reserved|non-reserved|non-reserved|
+|SEEK|non-reserved|non-reserved|non-reserved|
+|SELECT|reserved|non-reserved|reserved|
+|SEMI|reserved|reserved|non-reserved|
+|SENSITIVE|non-reserved|non-reserved|reserved|
+|SEPARATED|non-reserved|non-reserved|non-reserved|
+|SERDE|non-reserved|non-reserved|non-reserved|
+|SERDEPROPERTIES|non-reserved|non-reserved|non-reserved|
+|SESSION|non-reserved|non-reserved|non-reserved|
+|SESSION_USER|reserved|non-reserved|reserved|
+|SET|non-reserved|non-reserved|reserved|
+|SETS|non-reserved|non-reserved|non-reserved|
+|SHOW|non-reserved|non-reserved|non-reserved|
+|SIGNAL|non-reserved|non-reserved|reserved|
+|SIMILAR|non-reserved|non-reserved|reserved|
+|SIN|non-reserved|non-reserved|non-reserved|
+|SINH|non-reserved|non-reserved|non-reserved|
+|SIZE|non-reserved|non-reserved|non-reserved|
+|SKIP|non-reserved|non-reserved|non-reserved|
+|SKEWED|non-reserved|non-reserved|non-reserved|
+|SMALLINT|non-reserved|non-reserved|reserved|
+|SOME|reserved|non-reserved|reserved|
+|SORT|non-reserved|non-reserved|non-reserved|
+|SORTED|non-reserved|non-reserved|non-reserved|
+|SPACE|non-reserved|non-reserved|non-reserved|
+|SPECIFIC|non-reserved|non-reserved|reserved|
+|SPECIFICTYPE|non-reserved|non-reserved|reserved|
+|SQL|non-reserved|non-reserved|reserved|
+|SQLCODE|non-reserved|non-reserved|non-reserved|
+|SQLERROR|non-reserved|non-reserved|non-reserved|
+|SQLEXCEPTION|non-reserved|non-reserved|reserved|
+|SQLSTATE|non-reserved|non-reserved|reserved|
+|SQLWARNING|non-reserved|non-reserved|reserved|
+|SQRT|non-reserved|non-reserved|reserved|
+|START|non-reserved|non-reserved|reserved|
+|STATE|non-reserved|non-reserved|non-reserved|
+|STATIC|non-reserved|non-reserved|reserved|
+|STATISTICS|non-reserved|non-reserved|non-reserved|
+|STDDEV_POP|non-reserved|non-reserved|reserved|
+|STDDEV_SAMP|non-reserved|non-reserved|reserved|
+|STORED|non-reserved|non-reserved|non-reserved|
+|STRATIFY|non-reserved|non-reserved|non-reserved|
+|STRUCT|non-reserved|non-reserved|non-reserved|
+|SUBMULTISET|non-reserved|non-reserved|reserved|
+|SUBSET|non-reserved|non-reserved|non-reserved|
+|SUBSTRING|non-reserved|non-reserved|reserved|
+|SUBSTRING_REGEX|non-reserved|non-reserved|reserved|
+|SUCCEEDS|non-reserved|non-reserved|reserved|
+|SUM|non-reserved|non-reserved|reserved|
+|SYMMETRIC|non-reserved|non-reserved|reserved|
+|SYSTEM|non-reserved|non-reserved|reserved|
+|SYSTEM_TIME|non-reserved|non-reserved|reserved|
+|SYSTEM_USER|non-reserved|non-reserved|reserved|
+|TABLE|reserved|non-reserved|reserved|
+|TABLES|non-reserved|non-reserved|non-reserved|
+|TABLESAMPLE|non-reserved|non-reserved|reserved|
+|TAN|non-reserved|non-reserved|non-reserved|
+|TANH|non-reserved|non-reserved|non-reserved|
+|TBLPROPERTIES|non-reserved|non-reserved|non-reserved|
+|TEMPORARY|non-reserved|non-reserved|non-reserved|
+|TERMINATED|non-reserved|non-reserved|non-reserved|
+|THEN|reserved|non-reserved|reserved|
+|TIME|non-reserved|non-reserved|reserved|
+|TIMESTAMP|non-reserved|non-reserved|reserved|
+|TIMEZONE_HOUR|non-reserved|non-reserved|reserved|
+|TIMEZONE_MINUTE|non-reserved|non-reserved|reserved|
+|TO|reserved|non-reserved|reserved|
+|TOUCH|non-reserved|non-reserved|non-reserved|
+|TRAILING|reserved|non-reserved|reserved|
+|TRANSACTION|non-reserved|non-reserved|non-reserved|
+|TRANSACTIONS|non-reserved|non-reserved|non-reserved|
+|TRANSFORM|non-reserved|non-reserved|non-reserved|
+|TRANSLATE|non-reserved|non-reserved|reserved|
+|TRANSLATE_REGEX|non-reserved|non-reserved|reserved|
+|TRANSLATION|non-reserved|non-reserved|reserved|
+|TREAT|non-reserved|non-reserved|reserved|
+|TRIGGER|non-reserved|non-reserved|reserved|
+|TRIM|non-reserved|non-reserved|reserved|
+|TRIM_ARRAY|non-reserved|non-reserved|reserved|
+|TRUE|non-reserved|non-reserved|reserved|
+|TRUNCATE|non-reserved|non-reserved|reserved|
+|UESCAPE|non-reserved|non-reserved|reserved|
+|UNARCHIVE|non-reserved|non-reserved|non-reserved|
+|UNBOUNDED|non-reserved|non-reserved|non-reserved|
+|UNCACHE|non-reserved|non-reserved|non-reserved|
+|UNDER|non-reserved|non-reserved|non-reserved|
+|UNDO|non-reserved|non-reserved|reserved|
+|UNION|reserved|reserved|reserved|
+|UNIQUE|reserved|non-reserved|reserved|
+|UNKNOWN|non-reserved|non-reserved|reserved|
+|UNLOCK|non-reserved|non-reserved|non-reserved|
+|UNNEST|non-reserved|non-reserved|reserved|
+|UNSET|non-reserved|non-reserved|non-reserved|
+|UNTIL|non-reserved|non-reserved|reserved|
+|UPDATE|non-reserved|non-reserved|reserved|
+|UPPER|non-reserved|non-reserved|reserved|
+|USAGE|non-reserved|non-reserved|non-reserved|
+|USE|non-reserved|non-reserved|non-reserved|
+|USER|reserved|non-reserved|reserved|
+|USING|reserved|reserved|reserved|
+|VALUE|non-reserved|non-reserved|reserved|
+|VALUES|non-reserved|non-reserved|reserved|
+|VALUE_OF|non-reserved|non-reserved|reserved|
+|VAR_POP|non-reserved|non-reserved|reserved|
+|VAR_SAMP|non-reserved|non-reserved|reserved|
+|VARBINARY|non-reserved|non-reserved|reserved|
+|VARCHAR|non-reserved|non-reserved|reserved|
+|VARYING|non-reserved|non-reserved|reserved|
+|VERSIONING|non-reserved|non-reserved|reserved|
+|VIEW|non-reserved|non-reserved|non-reserved|
+|WHEN|reserved|non-reserved|reserved|
+|WHENEVER|non-reserved|non-reserved|reserved|
+|WHERE|reserved|non-reserved|reserved|
+|WHILE|non-reserved|non-reserved|reserved|
+|WIDTH_BUCKET|non-reserved|non-reserved|reserved|
+|WINDOW|non-reserved|non-reserved|reserved|
+|WITH|reserved|non-reserved|reserved|
+|WITHIN|non-reserved|non-reserved|reserved|
+|WITHOUT|non-reserved|non-reserved|reserved|
+|WORK|non-reserved|non-reserved|non-reserved|
+|WRITE|non-reserved|non-reserved|non-reserved|
+|YEAR|non-reserved|non-reserved|reserved|
+|ZONE|non-reserved|non-reserved|non-reserved|
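
The following Scala snippet is a minimal, hedged sketch (not part of the patch above) of what the table means in practice. It assumes a running `SparkSession` named `spark`, and uses `ANY` purely as an example of a keyword that the table lists as reserved in ANSI mode but non-reserved in the default mode; the exact error raised in ANSI mode may differ between Spark versions.

```scala
import org.apache.spark.sql.SparkSession

// Sketch only: a local session for trying the keyword behaviour.
val spark = SparkSession.builder().master("local[1]").appName("keyword-demo").getOrCreate()

// Default mode: ANY is non-reserved, so it can be used as a table alias.
spark.conf.set("spark.sql.parser.ansi.enabled", "false")
spark.sql("SELECT any.id FROM VALUES (1) AS any(id)").show()

// ANSI mode: ANY becomes reserved, so the same alias is expected to be rejected by the
// parser; a non-keyword alias (or a back-quoted identifier) has to be used instead.
spark.conf.set("spark.sql.parser.ansi.enabled", "true")
spark.sql("SELECT t.id FROM VALUES (1) AS t(id)").show()
```
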
diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index e76b53dbb4dc3..77db1c3d7d613 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -2980,7 +2980,7 @@ the effect of the change is not well-defined. For all of them: - Changes to the user-defined foreach sink (that is, the `ForeachWriter` code) are allowed, but the semantics of the change depends on the code. -- *Changes in projection / filter / map-like operations**: Some cases are allowed. For example: +- *Changes in projection / filter / map-like operations*: Some cases are allowed. For example: - Addition / deletion of filters is allowed: `sdf.selectExpr("a")` to `sdf.where(...).selectExpr("a").filter(...)`. diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousStream.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousStream.scala index 0e6171724402e..92686d24e2b8a 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousStream.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousStream.scala @@ -37,8 +37,7 @@ import org.apache.spark.sql.sources.v2.reader.streaming._ * @param offsetReader a reader used to get kafka offsets. Note that the actual data will be * read by per-task consumers generated later. * @param kafkaParams String params for per-task Kafka consumers. - * @param sourceOptions The [[org.apache.spark.sql.sources.v2.DataSourceOptions]] params which - * are not Kafka consumer params. + * @param sourceOptions Params which are not Kafka consumer params. * @param metadataPath Path to a directory this reader can use for writing metadata. * @param initialOffsets The Kafka offsets to start reading data at. 
* @param failOnDataLoss Flag indicating whether reading should fail in data loss @@ -77,7 +76,7 @@ class KafkaContinuousStream( } override def planInputPartitions(start: Offset): Array[InputPartition] = { - val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(start) + val oldStartPartitionOffsets = start.asInstanceOf[KafkaSourceOffset].partitionToOffsets val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala index 337a51ef7fd80..ae866b48ef74b 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala @@ -33,9 +33,9 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset} import org.apache.spark.sql.execution.streaming.sources.RateControlMicroBatchStream import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} -import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, Offset} +import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.UninterruptibleThread /** @@ -57,7 +57,7 @@ import org.apache.spark.util.UninterruptibleThread private[kafka010] class KafkaMicroBatchStream( kafkaOffsetReader: KafkaOffsetReader, executorKafkaParams: ju.Map[String, Object], - options: DataSourceOptions, + options: CaseInsensitiveStringMap, metadataPath: String, startingOffsets: KafkaOffsetRangeLimit, failOnDataLoss: Boolean) extends RateControlMicroBatchStream with Logging { @@ -66,8 +66,7 @@ private[kafka010] class KafkaMicroBatchStream( "kafkaConsumer.pollTimeoutMs", SparkEnv.get.conf.getTimeAsSeconds("spark.network.timeout", "120s") * 1000L) - private val maxOffsetsPerTrigger = - Option(options.get("maxOffsetsPerTrigger").orElse(null)).map(_.toLong) + private val maxOffsetsPerTrigger = Option(options.get("maxOffsetsPerTrigger")).map(_.toLong) private val rangeCalculator = KafkaOffsetRangeCalculator(options) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala index 6008794924052..1af8404b89c68 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.common.TopicPartition -import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.util.CaseInsensitiveStringMap /** @@ -91,8 +91,8 @@ private[kafka010] class KafkaOffsetRangeCalculator(val minPartitions: Option[Int private[kafka010] object KafkaOffsetRangeCalculator { - def apply(options: DataSourceOptions): KafkaOffsetRangeCalculator = { - val optionalValue = Option(options.get("minPartitions").orElse(null)).map(_.toInt) + def apply(options: 
CaseInsensitiveStringMap): KafkaOffsetRangeCalculator = { + val optionalValue = Option(options.get("minPartitions")).map(_.toInt) new KafkaOffsetRangeCalculator(optionalValue) } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala index 8d41c0da2b133..90d70439c5329 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.common.TopicPartition import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset} -import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset /** * An [[Offset]] for the [[KafkaSource]]. This one tracks all partitions of subscribed topics and * their offsets. */ private[kafka010] -case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends OffsetV2 { +case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends Offset { override val json = JsonUtils.partitionOffsets(partitionToOffsets) } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 9238899b0c00c..0b661b7eeaf08 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -31,11 +31,14 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.TableCapability._ import org.apache.spark.sql.sources.v2.reader.{Scan, ScanBuilder} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.WriteBuilder +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap /** * The provider class for all Kafka readers and writers. 
It is designed such that it throws @@ -47,7 +50,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with StreamSinkProvider with RelationProvider with CreatableRelationProvider - with StreamingWriteSupportProvider with TableProvider with Logging { import KafkaSourceProvider._ @@ -102,8 +104,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister failOnDataLoss(caseInsensitiveParams)) } - override def getTable(options: DataSourceOptions): KafkaTable = { - new KafkaTable(strategy(options.asMap().asScala.toMap)) + override def getTable(options: CaseInsensitiveStringMap): KafkaTable = { + new KafkaTable(strategy(options.asScala.toMap)) } /** @@ -180,20 +182,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } } - override def createStreamingWriteSupport( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceOptions): StreamingWriteSupport = { - import scala.collection.JavaConverters._ - - val topic = Option(options.get(TOPIC_OPTION_KEY).orElse(null)).map(_.trim) - // We convert the options argument from V2 -> Java map -> scala mutable -> scala immutable. - val producerParams = kafkaParamsForProducer(options.asMap.asScala.toMap) - - new KafkaStreamingWriteSupport(topic, producerParams, schema) - } - private def strategy(caseInsensitiveParams: Map[String, String]) = caseInsensitiveParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match { case ("assign", value) => @@ -365,23 +353,47 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } class KafkaTable(strategy: => ConsumerStrategy) extends Table - with SupportsMicroBatchRead with SupportsContinuousRead { + with SupportsRead with SupportsWrite { override def name(): String = s"Kafka $strategy" override def schema(): StructType = KafkaOffsetReader.kafkaSchema - override def newScanBuilder(options: DataSourceOptions): ScanBuilder = new ScanBuilder { + override def capabilities(): ju.Set[TableCapability] = { + Set(MICRO_BATCH_READ, CONTINUOUS_READ, STREAMING_WRITE).asJava + } + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new ScanBuilder { override def build(): Scan = new KafkaScan(options) } + + override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = { + new WriteBuilder { + private var inputSchema: StructType = _ + + override def withInputDataSchema(schema: StructType): WriteBuilder = { + this.inputSchema = schema + this + } + + override def buildForStreaming(): StreamingWrite = { + import scala.collection.JavaConverters._ + + assert(inputSchema != null) + val topic = Option(options.get(TOPIC_OPTION_KEY)).map(_.trim) + val producerParams = kafkaParamsForProducer(options.asScala.toMap) + new KafkaStreamingWrite(topic, producerParams, inputSchema) + } + } + } } - class KafkaScan(options: DataSourceOptions) extends Scan { + class KafkaScan(options: CaseInsensitiveStringMap) extends Scan { override def readSchema(): StructType = KafkaOffsetReader.kafkaSchema override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = { - val parameters = options.asMap().asScala.toMap + val parameters = options.asScala.toMap validateStreamOptions(parameters) // Each running query should use its own group id. 
Otherwise, the query may be only assigned // partial data since Kafka will assign partitions to multiple consumers having the same group @@ -410,7 +422,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } override def toContinuousStream(checkpointLocation: String): ContinuousStream = { - val parameters = options.asMap().asScala.toMap + val parameters = options.asScala.toMap validateStreamOptions(parameters) // Each running query should use its own group id. Otherwise, the query may be only assigned // partial data since Kafka will assign partitions to multiple consumers having the same group diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala similarity index 95% rename from external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala index 0d831c3884609..e3101e1572082 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.types.StructType /** @@ -33,18 +33,18 @@ import org.apache.spark.sql.types.StructType case object KafkaWriterCommitMessage extends WriterCommitMessage /** - * A [[StreamingWriteSupport]] for Kafka writing. Responsible for generating the writer factory. + * A [[StreamingWrite]] for Kafka writing. Responsible for generating the writer factory. * * @param topic The topic this writer is responsible for. If None, topic will be inferred from * a `topic` field in the incoming data. * @param producerParams Parameters for Kafka producers in each task. * @param schema The schema of the input data. 
*/ -class KafkaStreamingWriteSupport( +class KafkaStreamingWrite( topic: Option[String], producerParams: ju.Map[String, Object], schema: StructType) - extends StreamingWriteSupport { + extends StreamingWrite { validateQuery(schema.toAttributes, producerParams, topic) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala index b21037b1340ce..3c3aeebc48b7f 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala @@ -22,9 +22,8 @@ import java.util.Locale import org.apache.kafka.clients.producer.ProducerConfig import org.apache.kafka.common.serialization.ByteArraySerializer import org.scalatest.time.SpanSugar._ -import scala.collection.JavaConverters._ -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection} import org.apache.spark.sql.streaming._ import org.apache.spark.sql.types.{BinaryType, DataType} @@ -227,39 +226,23 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { val topic = newTopic() testUtils.createTopic(topic) - /* No topic field or topic option */ - var writer: StreamingQuery = null - var ex: Exception = null - try { - writer = createKafkaWriter(input.toDF())( + val ex = intercept[AnalysisException] { + /* No topic field or topic option */ + createKafkaWriter(input.toDF())( withSelectExpr = "value as key", "value" ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - eventually(timeout(streamingTimeout)) { - assert(writer.exception.isDefined) - ex = writer.exception.get - } - } finally { - writer.stop() } assert(ex.getMessage .toLowerCase(Locale.ROOT) .contains("topic option required when no 'topic' attribute is present")) - try { + val ex2 = intercept[AnalysisException] { /* No value field */ - writer = createKafkaWriter(input.toDF())( + createKafkaWriter(input.toDF())( withSelectExpr = s"'$topic' as topic", "value as key" ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - eventually(timeout(streamingTimeout)) { - assert(writer.exception.isDefined) - ex = writer.exception.get - } - } finally { - writer.stop() } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + assert(ex2.getMessage.toLowerCase(Locale.ROOT).contains( "required attribute 'value' not found")) } @@ -278,53 +261,30 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { val topic = newTopic() testUtils.createTopic(topic) - var writer: StreamingQuery = null - var ex: Exception = null - try { + val ex = intercept[AnalysisException] { /* topic field wrong type */ - writer = createKafkaWriter(input.toDF())( + createKafkaWriter(input.toDF())( withSelectExpr = s"CAST('1' as INT) as topic", "value" ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - eventually(timeout(streamingTimeout)) { - assert(writer.exception.isDefined) - ex = writer.exception.get - } - } finally { - writer.stop() } assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("topic type must be a string")) - try { + val ex2 = intercept[AnalysisException] { /* value field wrong type */ - writer = createKafkaWriter(input.toDF())( + createKafkaWriter(input.toDF())( withSelectExpr = s"'$topic' 
as topic", "CAST(value as INT) as value" ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - eventually(timeout(streamingTimeout)) { - assert(writer.exception.isDefined) - ex = writer.exception.get - } - } finally { - writer.stop() } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + assert(ex2.getMessage.toLowerCase(Locale.ROOT).contains( "value attribute type must be a string or binary")) - try { + val ex3 = intercept[AnalysisException] { /* key field wrong type */ - writer = createKafkaWriter(input.toDF())( + createKafkaWriter(input.toDF())( withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as key", "value" ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - eventually(timeout(streamingTimeout)) { - assert(writer.exception.isDefined) - ex = writer.exception.get - } - } finally { - writer.stop() } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + assert(ex3.getMessage.toLowerCase(Locale.ROOT).contains( "key attribute type must be a string or binary")) } @@ -369,35 +329,22 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("subscribe", inputTopic) .load() - var writer: StreamingQuery = null - var ex: Exception = null - try { - writer = createKafkaWriter( + + val ex = intercept[IllegalArgumentException] { + createKafkaWriter( input.toDF(), withOptions = Map("kafka.key.serializer" -> "foo"))() - eventually(timeout(streamingTimeout)) { - assert(writer.exception.isDefined) - ex = writer.exception.get - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "kafka option 'key.serializer' is not supported")) - } finally { - writer.stop() } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "kafka option 'key.serializer' is not supported")) - try { - writer = createKafkaWriter( + val ex2 = intercept[IllegalArgumentException] { + createKafkaWriter( input.toDF(), withOptions = Map("kafka.value.serializer" -> "foo"))() - eventually(timeout(streamingTimeout)) { - assert(writer.exception.isDefined) - ex = writer.exception.get - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "kafka option 'value.serializer' is not supported")) - } finally { - writer.stop() } + assert(ex2.getMessage.toLowerCase(Locale.ROOT).contains( + "kafka option 'value.serializer' is not supported")) } test("generic - write big data with small producer buffer") { diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala index ad1c2c59d9c8e..9ee8cbfa1bef4 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -30,7 +30,6 @@ import org.apache.spark.sql.test.TestSparkSession // Trait to configure StreamTest for kafka continuous execution tests. trait KafkaContinuousTest extends KafkaSourceTest { override val defaultTrigger = Trigger.Continuous(1000) - override val defaultUseV2Sink = true // We need more than the default local[2] to be able to schedule all partitions simultaneously. 
override protected def createSparkSession = new TestSparkSession( diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 8fd5790d753af..d2503a219a16e 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -41,10 +41,11 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.kafka010.KafkaSourceProvider._ -import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.reader.streaming.SparkDataStream import org.apache.spark.sql.streaming.{StreamTest, Trigger} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.util.CaseInsensitiveStringMap abstract class KafkaSourceTest extends StreamTest with SharedSQLContext with KafkaTest { @@ -94,7 +95,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext with Kaf message: String = "", topicAction: (String, Option[Int]) => Unit = (_, _) => {}) extends AddData { - override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + override def addData(query: Option[StreamExecution]): (SparkDataStream, Offset) = { query match { // Make sure no Spark job is running when deleting a topic case Some(m: MicroBatchExecution) => m.processAllAvailable() @@ -114,7 +115,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext with Kaf query.nonEmpty, "Cannot add data when there is no query for finding the active kafka source") - val sources: Seq[BaseStreamingSource] = { + val sources: Seq[SparkDataStream] = { query.get.logicalPlan.collect { case StreamingExecutionRelation(source: KafkaSource, _) => source case r: StreamingDataSourceV2Relation if r.stream.isInstanceOf[KafkaMicroBatchStream] || @@ -1118,7 +1119,7 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { "kafka.bootstrap.servers" -> testUtils.brokerAddress, "subscribe" -> topic ) ++ Option(minPartitions).map { p => "minPartitions" -> p} - val dsOptions = new DataSourceOptions(options.asJava) + val dsOptions = new CaseInsensitiveStringMap(options.asJava) val table = provider.getTable(dsOptions) val stream = table.newScanBuilder(dsOptions).build().toMicroBatchStream(dir.getAbsolutePath) val inputPartitions = stream.planInputPartitions( diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala index 2ccf3e291bea7..7ffdaab3e74fb 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala @@ -22,13 +22,13 @@ import scala.collection.JavaConverters._ import org.apache.kafka.common.TopicPartition import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.util.CaseInsensitiveStringMap class 
KafkaOffsetRangeCalculatorSuite extends SparkFunSuite { def testWithMinPartitions(name: String, minPartition: Int) (f: KafkaOffsetRangeCalculator => Unit): Unit = { - val options = new DataSourceOptions(Map("minPartitions" -> minPartition.toString).asJava) + val options = new CaseInsensitiveStringMap(Map("minPartitions" -> minPartition.toString).asJava) test(s"with minPartition = $minPartition: $name") { f(KafkaOffsetRangeCalculator(options)) } @@ -36,7 +36,7 @@ class KafkaOffsetRangeCalculatorSuite extends SparkFunSuite { test("with no minPartition: N TopicPartitions to N offset ranges") { - val calc = KafkaOffsetRangeCalculator(DataSourceOptions.empty()) + val calc = KafkaOffsetRangeCalculator(CaseInsensitiveStringMap.empty()) assert( calc.getRanges( fromOffsets = Map(tp1 -> 1), @@ -64,7 +64,7 @@ class KafkaOffsetRangeCalculatorSuite extends SparkFunSuite { } test("with no minPartition: empty ranges ignored") { - val calc = KafkaOffsetRangeCalculator(DataSourceOptions.empty()) + val calc = KafkaOffsetRangeCalculator(CaseInsensitiveStringMap.empty()) assert( calc.getRanges( fromOffsets = Map(tp1 -> 1, tp2 -> 1), diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 2fc9754ecfe1e..4cc467a6d664a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -696,12 +696,14 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { withClue("transform should fail when ids exceed integer range. ") { val model = als.fit(df) def testTransformIdExceedsIntRange[A : Encoder](dataFrame: DataFrame): Unit = { - assert(intercept[SparkException] { + val e1 = intercept[SparkException] { model.transform(dataFrame).first - }.getMessage.contains(msg)) - assert(intercept[StreamingQueryException] { + } + TestUtils.assertExceptionMsg(e1, msg) + val e2 = intercept[StreamingQueryException] { testTransformer[A](dataFrame, model, "prediction") { _ => } - }.getMessage.contains(msg)) + } + TestUtils.assertExceptionMsg(e2, msg) } testTransformIdExceedsIntRange[(Long, Int)](df.select(df("user_big").as("user"), df("item"))) diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala index 514fa7f2e1b8d..0861a3a2d099e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala @@ -21,7 +21,7 @@ import java.io.File import org.scalatest.Suite -import org.apache.spark.{DebugFilesystem, SparkConf, SparkContext} +import org.apache.spark.{DebugFilesystem, SparkConf, SparkContext, TestUtils} import org.apache.spark.internal.config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK import org.apache.spark.ml.{PredictionModel, Transformer} import org.apache.spark.ml.linalg.Vector @@ -129,21 +129,17 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite => expectedMessagePart : String, firstResultCol: String) { - def hasExpectedMessage(exception: Throwable): Boolean = - exception.getMessage.contains(expectedMessagePart) || - (exception.getCause != null && exception.getCause.getMessage.contains(expectedMessagePart)) - withClue(s"""Expected message part "${expectedMessagePart}" is not found in DF test.""") { val exceptionOnDf = intercept[Throwable] { testTransformerOnDF(dataframe, transformer, firstResultCol)(_ => Unit) } - assert(hasExpectedMessage(exceptionOnDf)) + 
TestUtils.assertExceptionMsg(exceptionOnDf, expectedMessagePart) } withClue(s"""Expected message part "${expectedMessagePart}" is not found in stream test.""") { val exceptionOnStreamData = intercept[Throwable] { testTransformerOnStreamData(dataframe, transformer, firstResultCol)(_ => Unit) } - assert(hasExpectedMessage(exceptionOnStreamData)) + TestUtils.assertExceptionMsg(exceptionOnStreamData, expectedMessagePart) } } diff --git a/pom.xml b/pom.xml index 5a23ffbc41dd3..9f5699582f5af 100644 --- a/pom.xml +++ b/pom.xml @@ -2381,6 +2381,7 @@ -feature -explaintypes -Yno-adapted-args + -target:jvm-1.8 -Xms1024m diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 76744af2327c5..85127f1fb5c9a 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -165,7 +165,84 @@ object MimaExcludes { case ReversedMissingMethodProblem(meth) => !meth.owner.fullName.startsWith("org.apache.spark.sql.sources.v2") case _ => true - } + }, + + // [SPARK-27521][SQL] Move data source v2 to catalyst module + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.vectorized.ColumnarBatch"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.vectorized.ArrowColumnVector"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.vectorized.ColumnarRow"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.vectorized.ColumnarArray"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.vectorized.ColumnarMap"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.vectorized.ColumnVector"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.GreaterThanOrEqual"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.StringEndsWith"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LessThanOrEqual$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.In$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Not"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.IsNotNull"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LessThan"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LessThanOrEqual"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.EqualNullSafe$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.GreaterThan$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.In"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.And"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.StringStartsWith$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.EqualNullSafe"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.StringEndsWith$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.GreaterThanOrEqual$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Not$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.IsNull$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LessThan$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.IsNotNull$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Or"), + 
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.EqualTo$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.GreaterThan"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.StringContains"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Filter"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.IsNull"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.EqualTo"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.And$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Or$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.StringStartsWith"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.StringContains$"), + + // [SPARK-26216][SQL] Do not use case class as public API (UserDefinedFunction) + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.UserDefinedFunction$"), + ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.sql.expressions.UserDefinedFunction"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.inputTypes"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.nullableTypes_="), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.dataType"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.f"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.this"), + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.asNonNullable"), + ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.asNonNullable"), + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.nullable"), + ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.nullable"), + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.asNondeterministic"), + ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.asNondeterministic"), + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.deterministic"), + ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.deterministic"), + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.apply"), + ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.apply"), + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.withName"), + ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.withName"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.productElement"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.productArity"), + 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy$default$2"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.canEqual"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy$default$1"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.productIterator"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.productPrefix"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy$default$3"), + + // [SPARK-11215][ML] Add multiple columns support to StringIndexer + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.StringIndexer.validateAndTransformSchema"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.StringIndexerModel.validateAndTransformSchema"), + + // [SPARK-26616][MLlib] Expose document frequency in IDFModel + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.feature.IDFModel.this"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.feature.IDF#DocumentFrequencyAggregator.idf") ) // Exclude rules for 2.4.x diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index fc23b9d99c34a..6cc47ccdbd431 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -186,7 +186,7 @@ def exception(self): je = self._jsq.exception().get() msg = je.toString().split(': ', 1)[1] # Drop the Java StreamingQueryException type info stackTrace = '\n\t at '.join(map(lambda x: x.toString(), je.getStackTrace())) - return StreamingQueryException(msg, stackTrace) + return StreamingQueryException(msg, stackTrace, je.getCause()) else: return None diff --git a/python/pyspark/sql/tests/test_streaming.py b/python/pyspark/sql/tests/test_streaming.py index 4b71759f74a55..1bd81c4411202 100644 --- a/python/pyspark/sql/tests/test_streaming.py +++ b/python/pyspark/sql/tests/test_streaming.py @@ -224,11 +224,19 @@ def test_stream_exception(self): self.fail("bad udf should fail the query") except StreamingQueryException as e: # This is expected - self.assertTrue("ZeroDivisionError" in e.desc) + self._assert_exception_tree_contains_msg(e, "ZeroDivisionError") finally: sq.stop() self.assertTrue(type(sq.exception()) is StreamingQueryException) - self.assertTrue("ZeroDivisionError" in sq.exception().desc) + self._assert_exception_tree_contains_msg(sq.exception(), "ZeroDivisionError") + + def _assert_exception_tree_contains_msg(self, exception, msg): + e = exception + contains = msg in e.desc + while e.cause is not None and not contains: + e = e.cause + contains = msg in e.desc + self.assertTrue(contains, "Exception tree doesn't contain the expected message: %s" % msg) def test_query_manager_await_termination(self): df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index bdb3a1467f1d8..b80bc4822b6f6 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -19,9 +19,10 @@ class CapturedException(Exception): - def __init__(self, desc, stackTrace): + def __init__(self, desc, stackTrace, 
cause=None): self.desc = desc self.stackTrace = stackTrace + self.cause = convert_exception(cause) if cause is not None else None def __str__(self): return repr(self.desc) @@ -57,27 +58,41 @@ class QueryExecutionException(CapturedException): """ +class UnknownException(CapturedException): + """ + None of the above exceptions. + """ + + +def convert_exception(e): + s = e.toString() + stackTrace = '\n\t at '.join(map(lambda x: x.toString(), e.getStackTrace())) + c = e.getCause() + if s.startswith('org.apache.spark.sql.AnalysisException: '): + return AnalysisException(s.split(': ', 1)[1], stackTrace, c) + if s.startswith('org.apache.spark.sql.catalyst.analysis'): + return AnalysisException(s.split(': ', 1)[1], stackTrace, c) + if s.startswith('org.apache.spark.sql.catalyst.parser.ParseException: '): + return ParseException(s.split(': ', 1)[1], stackTrace, c) + if s.startswith('org.apache.spark.sql.streaming.StreamingQueryException: '): + return StreamingQueryException(s.split(': ', 1)[1], stackTrace, c) + if s.startswith('org.apache.spark.sql.execution.QueryExecutionException: '): + return QueryExecutionException(s.split(': ', 1)[1], stackTrace, c) + if s.startswith('java.lang.IllegalArgumentException: '): + return IllegalArgumentException(s.split(': ', 1)[1], stackTrace, c) + return UnknownException(s, stackTrace, c) + + def capture_sql_exception(f): def deco(*a, **kw): try: return f(*a, **kw) except py4j.protocol.Py4JJavaError as e: - s = e.java_exception.toString() - stackTrace = '\n\t at '.join(map(lambda x: x.toString(), - e.java_exception.getStackTrace())) - if s.startswith('org.apache.spark.sql.AnalysisException: '): - raise AnalysisException(s.split(': ', 1)[1], stackTrace) - if s.startswith('org.apache.spark.sql.catalyst.analysis'): - raise AnalysisException(s.split(': ', 1)[1], stackTrace) - if s.startswith('org.apache.spark.sql.catalyst.parser.ParseException: '): - raise ParseException(s.split(': ', 1)[1], stackTrace) - if s.startswith('org.apache.spark.sql.streaming.StreamingQueryException: '): - raise StreamingQueryException(s.split(': ', 1)[1], stackTrace) - if s.startswith('org.apache.spark.sql.execution.QueryExecutionException: '): - raise QueryExecutionException(s.split(': ', 1)[1], stackTrace) - if s.startswith('java.lang.IllegalArgumentException: '): - raise IllegalArgumentException(s.split(': ', 1)[1], stackTrace) - raise + converted = convert_exception(e.java_exception) + if not isinstance(converted, UnknownException): + raise converted + else: + raise return deco diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 16ecebf159c1f..323032fbfd1f9 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -109,6 +109,10 @@ 2.7.3 jar + + org.apache.arrow + arrow-vector + target/scala-${scala.binary.version}/classes diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index b39681d886c5c..4133331c7fc40 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -44,6 +44,11 @@ grammar SqlBase; return true; } } + + /** + * When true, ANSI SQL parsing mode is enabled. 
+ */ + public boolean ansi = false; } singleStatement @@ -58,6 +63,10 @@ singleTableIdentifier : tableIdentifier EOF ; +singleMultipartIdentifier + : multipartIdentifier EOF + ; + singleFunctionIdentifier : functionIdentifier EOF ; @@ -73,14 +82,14 @@ singleTableSchema statement : query #statementDefault | USE db=identifier #use - | CREATE DATABASE (IF NOT EXISTS)? identifier + | CREATE database (IF NOT EXISTS)? identifier (COMMENT comment=STRING)? locationSpec? (WITH DBPROPERTIES tablePropertyList)? #createDatabase - | ALTER DATABASE identifier SET DBPROPERTIES tablePropertyList #setDatabaseProperties - | DROP DATABASE (IF EXISTS)? identifier (RESTRICT | CASCADE)? #dropDatabase + | ALTER database identifier SET DBPROPERTIES tablePropertyList #setDatabaseProperties + | DROP database (IF EXISTS)? identifier (RESTRICT | CASCADE)? #dropDatabase | createTableHeader ('(' colTypeList ')')? tableProvider ((OPTIONS options=tablePropertyList) | - (PARTITIONED BY partitionColumnNames=identifierList) | + (PARTITIONED BY partitioning=transformList) | bucketSpec | locationSpec | (COMMENT comment=STRING) | @@ -127,8 +136,8 @@ statement DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* #dropTablePartitions | ALTER TABLE tableIdentifier partitionSpec? SET locationSpec #setTableLocation | ALTER TABLE tableIdentifier RECOVER PARTITIONS #recoverPartitions - | DROP TABLE (IF EXISTS)? tableIdentifier PURGE? #dropTable - | DROP VIEW (IF EXISTS)? tableIdentifier #dropTable + | DROP TABLE (IF EXISTS)? multipartIdentifier PURGE? #dropTable + | DROP VIEW (IF EXISTS)? multipartIdentifier #dropView | CREATE (OR REPLACE)? (GLOBAL? TEMPORARY)? VIEW (IF NOT EXISTS)? tableIdentifier identifierCommentList? (COMMENT STRING)? @@ -158,9 +167,10 @@ statement (LIKE? (qualifiedName | pattern=STRING))? #showFunctions | SHOW CREATE TABLE tableIdentifier #showCreateTable | (DESC | DESCRIBE) FUNCTION EXTENDED? describeFuncName #describeFunction - | (DESC | DESCRIBE) DATABASE EXTENDED? identifier #describeDatabase + | (DESC | DESCRIBE) database EXTENDED? identifier #describeDatabase | (DESC | DESCRIBE) TABLE? option=(EXTENDED | FORMATTED)? tableIdentifier partitionSpec? describeColName? #describeTable + | (DESC | DESCRIBE) QUERY? queryToDesc #describeQuery | REFRESH TABLE tableIdentifier #refreshTable | REFRESH (STRING | .*?) #refreshResource | CACHE LAZY? TABLE tableIdentifier @@ -227,7 +237,7 @@ unsupportedHiveNativeCommands ; createTableHeader - : CREATE TEMPORARY? EXTERNAL? TABLE (IF NOT EXISTS)? tableIdentifier + : CREATE TEMPORARY? EXTERNAL? TABLE (IF NOT EXISTS)? multipartIdentifier ; bucketSpec @@ -250,6 +260,10 @@ query : ctes? queryNoWith ; +queryToDesc + : queryTerm queryOrganization + ; + insertInto : INSERT OVERWRITE TABLE tableIdentifier (partitionSpec (IF NOT EXISTS)?)? #insertOverwriteTable | INSERT INTO TABLE? tableIdentifier partitionSpec? #insertIntoTable @@ -269,6 +283,11 @@ partitionVal : identifier (EQ constant)? ; +database + : DATABASE + | SCHEMA + ; + describeFuncName : qualifiedName | STRING @@ -539,6 +558,10 @@ rowFormat (NULL DEFINED AS nullDefinedAs=STRING)? #rowFormatDelimited ; +multipartIdentifier + : parts+=identifier ('.' parts+=identifier)* + ; + tableIdentifier : (db=identifier '.')? 
table=identifier ; @@ -555,6 +578,21 @@ namedExpressionSeq : namedExpression (',' namedExpression)* ; +transformList + : '(' transforms+=transform (',' transforms+=transform)* ')' + ; + +transform + : qualifiedName #identityTransform + | transformName=identifier + '(' argument+=transformArgument (',' argument+=transformArgument)* ')' #applyTransform + ; + +transformArgument + : qualifiedName + | constant + ; + expression : booleanExpression ; @@ -720,14 +758,15 @@ qualifiedName identifier : strictIdentifier - | ANTI | FULL | INNER | LEFT | SEMI | RIGHT | NATURAL | JOIN | CROSS | ON - | UNION | INTERSECT | EXCEPT | SETMINUS + | {ansi}? ansiReserved + | {!ansi}? defaultReserved ; strictIdentifier - : IDENTIFIER #unquotedIdentifier - | quotedIdentifier #quotedIdentifierAlternative - | nonReserved #unquotedIdentifier + : IDENTIFIER #unquotedIdentifier + | quotedIdentifier #quotedIdentifierAlternative + | {ansi}? ansiNonReserved #unquotedIdentifier + | {!ansi}? nonReserved #unquotedIdentifier ; quotedIdentifier @@ -744,40 +783,67 @@ number | MINUS? BIGDECIMAL_LITERAL #bigDecimalLiteral ; +// NOTE: You must follow a rule below when you add a new ANTLR token in this file: +// - All the ANTLR tokens = UNION(`ansiReserved`, `ansiNonReserved`) = UNION(`defaultReserved`, `nonReserved`) +// +// Let's say you add a new token `NEWTOKEN` and this is not reserved regardless of a `spark.sql.parser.ansi.enabled` +// value. In this case, you must add a token `NEWTOKEN` in both `ansiNonReserved` and `nonReserved`. + +// The list of the reserved keywords when `spark.sql.parser.ansi.enabled` is true. Currently, we only reserve +// the ANSI keywords that almost all the ANSI SQL standards (SQL-92, SQL-99, SQL-2003, SQL-2008, SQL-2011, +// and SQL-2016) and PostgreSQL reserve. +ansiReserved + : ALL | AND | ANTI | ANY | AS | AUTHORIZATION | BOTH | CASE | CAST | CHECK | COLLATE | COLUMN | CONSTRAINT | CREATE + | CROSS | CURRENT_DATE | CURRENT_TIME | CURRENT_TIMESTAMP | CURRENT_USER | DISTINCT | ELSE | END | EXCEPT | FALSE + | FETCH | FOR | FOREIGN | FROM | FULL | GRANT | GROUP | HAVING | IN | INNER | INTERSECT | INTO | JOIN | IS + | LEADING | LEFT | NATURAL | NOT | NULL | ON | ONLY | OR | ORDER | OUTER | OVERLAPS | PRIMARY | REFERENCES | RIGHT + | SELECT | SEMI | SESSION_USER | SETMINUS | SOME | TABLE | THEN | TO | TRAILING | UNION | UNIQUE | USER | USING + | WHEN | WHERE | WITH + ; + + +// The list of the non-reserved keywords when `spark.sql.parser.ansi.enabled` is true. 
+ansiNonReserved + : ADD | AFTER | ALTER | ANALYZE | ARCHIVE | ARRAY | ASC | AT | BETWEEN | BUCKET | BUCKETS | BY | CACHE | CASCADE + | CHANGE | CLEAR | CLUSTER | CLUSTERED | CODEGEN | COLLECTION | COLUMNS | COMMENT | COMMIT | COMPACT | COMPACTIONS + | COMPUTE | CONCATENATE | COST | CUBE | CURRENT | DATA | DATABASE | DATABASES | DBPROPERTIES | DEFINED | DELETE + | DELIMITED | DESC | DESCRIBE | DFS | DIRECTORIES | DIRECTORY | DISTRIBUTE | DIV | DROP | ESCAPED | EXCHANGE + | EXISTS | EXPLAIN | EXPORT | EXTENDED | EXTERNAL | EXTRACT | FIELDS | FILEFORMAT | FIRST | FOLLOWING | FORMAT + | FORMATTED | FUNCTION | FUNCTIONS | GLOBAL | GROUPING | IF | IGNORE | IMPORT | INDEX | INDEXES | INPATH + | INPUTFORMAT | INSERT | INTERVAL | ITEMS | KEYS | LAST | LATERAL | LAZY | LIKE | LIMIT | LINES | LIST | LOAD + | LOCAL | LOCATION | LOCK | LOCKS | LOGICAL | MACRO | MAP | MSCK | NO | NULLS | OF | OPTION | OPTIONS | OUT + | OUTPUTFORMAT | OVER | OVERWRITE | PARTITION | PARTITIONED | PARTITIONS | PERCENT | PERCENTLIT | PIVOT | PRECEDING + | PRINCIPALS | PURGE | QUERY | RANGE | RECORDREADER | RECORDWRITER | RECOVER | REDUCE | REFRESH | RENAME | REPAIR | REPLACE + | RESET | RESTRICT | REVOKE | RLIKE | ROLE | ROLES | ROLLBACK | ROLLUP | ROW | ROWS | SCHEMA | SEPARATED | SERDE + | SERDEPROPERTIES | SET | SETS | SHOW | SKEWED | SORT | SORTED | START | STATISTICS | STORED | STRATIFY | STRUCT + | TABLES | TABLESAMPLE | TBLPROPERTIES | TEMPORARY | TERMINATED | TOUCH | TRANSACTION | TRANSACTIONS | TRANSFORM + | TRUE | TRUNCATE | UNARCHIVE | UNBOUNDED | UNCACHE | UNLOCK | UNSET | USE | VALUES | VIEW | WINDOW + ; + +defaultReserved + : ANTI | CROSS | EXCEPT | FULL | INNER | INTERSECT | JOIN | LEFT | NATURAL | ON | RIGHT | SEMI | SETMINUS | UNION + | USING + ; + nonReserved - : SHOW | TABLES | COLUMNS | COLUMN | PARTITIONS | FUNCTIONS | DATABASES - | ADD - | OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | LAST | FIRST | AFTER - | MAP | ARRAY | STRUCT - | PIVOT | LATERAL | WINDOW | REDUCE | TRANSFORM | SERDE | SERDEPROPERTIES | RECORDREADER - | DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED - | EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | GLOBAL | TEMPORARY | OPTIONS - | GROUPING | CUBE | ROLLUP - | EXPLAIN | FORMAT | LOGICAL | FORMATTED | CODEGEN | COST - | TABLESAMPLE | USE | TO | BUCKET | PERCENTLIT | OUT | OF - | SET | RESET - | VIEW | REPLACE - | IF - | POSITION - | EXTRACT - | NO | DATA - | START | TRANSACTION | COMMIT | ROLLBACK | IGNORE - | SORT | CLUSTER | DISTRIBUTE | UNSET | TBLPROPERTIES | SKEWED | STORED | DIRECTORIES | LOCATION - | EXCHANGE | ARCHIVE | UNARCHIVE | FILEFORMAT | TOUCH | COMPACT | CONCATENATE | CHANGE - | CASCADE | RESTRICT | BUCKETS | CLUSTERED | SORTED | PURGE | INPUTFORMAT | OUTPUTFORMAT - | DBPROPERTIES | DFS | TRUNCATE | COMPUTE | LIST - | STATISTICS | ANALYZE | PARTITIONED | EXTERNAL | DEFINED | RECORDWRITER - | REVOKE | GRANT | LOCK | UNLOCK | MSCK | REPAIR | RECOVER | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE - | ROLES | COMPACTIONS | PRINCIPALS | TRANSACTIONS | INDEX | INDEXES | LOCKS | OPTION | LOCAL | INPATH - | ASC | DESC | LIMIT | RENAME | SETS - | AT | NULLS | OVERWRITE | ALL | ANY | ALTER | AS | BETWEEN | BY | CREATE | DELETE - | DESCRIBE | DROP | EXISTS | FALSE | FOR | GROUP | IN | INSERT | INTO | IS |LIKE - | NULL | ORDER | OUTER | TABLE | TRUE | WITH | RLIKE - | AND | CASE | CAST | DISTINCT | DIV | ELSE | END | FUNCTION | INTERVAL | MACRO | OR | STRATIFY | THEN - | UNBOUNDED | WHEN - | 
DATABASE | SELECT | FROM | WHERE | HAVING | TO | TABLE | WITH | NOT - | DIRECTORY - | BOTH | LEADING | TRAILING + : ADD | AFTER | ALL | ALTER | ANALYZE | AND | ANY | ARCHIVE | ARRAY | AS | ASC | AT | AUTHORIZATION | BETWEEN + | BOTH | BUCKET | BUCKETS | BY | CACHE | CASCADE | CASE | CAST | CHANGE | CHECK | CLEAR | CLUSTER | CLUSTERED + | CODEGEN | COLLATE | COLLECTION | COLUMN | COLUMNS | COMMENT | COMMIT | COMPACT | COMPACTIONS | COMPUTE + | CONCATENATE | CONSTRAINT | COST | CREATE | CUBE | CURRENT | CURRENT_DATE | CURRENT_TIME | CURRENT_TIMESTAMP + | CURRENT_USER | DATA | DATABASE | DATABASES | DBPROPERTIES | DEFINED | DELETE | DELIMITED | DESC | DESCRIBE | DFS + | DIRECTORIES | DIRECTORY | DISTINCT | DISTRIBUTE | DIV | DROP | ELSE | END | ESCAPED | EXCHANGE | EXISTS | EXPLAIN + | EXPORT | EXTENDED | EXTERNAL | EXTRACT | FALSE | FETCH | FIELDS | FILEFORMAT | FIRST | FOLLOWING | FOR | FOREIGN + | FORMAT | FORMATTED | FROM | FUNCTION | FUNCTIONS | GLOBAL | GRANT | GROUP | GROUPING | HAVING | IF | IGNORE + | IMPORT | IN | INDEX | INDEXES | INPATH | INPUTFORMAT | INSERT | INTERVAL | INTO | IS | ITEMS | KEYS | LAST + | LATERAL | LAZY | LEADING | LIKE | LIMIT | LINES | LIST | LOAD | LOCAL | LOCATION | LOCK | LOCKS | LOGICAL | MACRO + | MAP | MSCK | NO | NOT | NULL | NULLS | OF | ONLY | OPTION | OPTIONS | OR | ORDER | OUT | OUTER | OUTPUTFORMAT + | OVER | OVERLAPS | OVERWRITE | PARTITION | PARTITIONED | PARTITIONS | PERCENTLIT | PIVOT | POSITION | PRECEDING + | PRIMARY | PRINCIPALS | PURGE | QUERY | RANGE | RECORDREADER | RECORDWRITER | RECOVER | REDUCE | REFERENCES | REFRESH + | RENAME | REPAIR | REPLACE | RESET | RESTRICT | REVOKE | RLIKE | ROLE | ROLES | ROLLBACK | ROLLUP | ROW | ROWS + | SELECT | SEPARATED | SERDE | SERDEPROPERTIES | SESSION_USER | SET | SETS | SHOW | SKEWED | SOME | SORT | SORTED + | START | STATISTICS | STORED | STRATIFY | STRUCT | TABLE | TABLES | TABLESAMPLE | TBLPROPERTIES | TEMPORARY + | TERMINATED | THEN | TO | TOUCH | TRAILING | TRANSACTION | TRANSACTIONS | TRANSFORM | TRUE | TRUNCATE | UNARCHIVE + | UNBOUNDED | UNCACHE | UNLOCK | UNIQUE | UNSET | USE | USER | VALUES | VIEW | WHEN | WHERE | WINDOW | WITH ; SELECT: 'SELECT'; @@ -850,6 +916,7 @@ WITH: 'WITH'; VALUES: 'VALUES'; CREATE: 'CREATE'; TABLE: 'TABLE'; +QUERY: 'QUERY'; DIRECTORY: 'DIRECTORY'; VIEW: 'VIEW'; REPLACE: 'REPLACE'; @@ -981,7 +1048,8 @@ SORTED: 'SORTED'; PURGE: 'PURGE'; INPUTFORMAT: 'INPUTFORMAT'; OUTPUTFORMAT: 'OUTPUTFORMAT'; -DATABASE: 'DATABASE' | 'SCHEMA'; +SCHEMA: 'SCHEMA'; +DATABASE: 'DATABASE'; DATABASES: 'DATABASES' | 'SCHEMAS'; DFS: 'DFS'; TRUNCATE: 'TRUNCATE'; @@ -1014,6 +1082,24 @@ OPTION: 'OPTION'; ANTI: 'ANTI'; LOCAL: 'LOCAL'; INPATH: 'INPATH'; +AUTHORIZATION: 'AUTHORIZATION'; +CHECK: 'CHECK'; +COLLATE: 'COLLATE'; +CONSTRAINT: 'CONSTRAINT'; +CURRENT_DATE: 'CURRENT_DATE'; +CURRENT_TIME: 'CURRENT_TIME'; +CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP'; +CURRENT_USER: 'CURRENT_USER'; +FETCH: 'FETCH'; +FOREIGN: 'FOREIGN'; +ONLY: 'ONLY'; +OVERLAPS: 'OVERLAPS'; +PRIMARY: 'PRIMARY'; +REFERENCES: 'REFERENCES'; +SESSION_USER: 'SESSION_USER'; +SOME: 'SOME'; +UNIQUE: 'UNIQUE'; +USER: 'USER'; STRING : '\'' ( ~('\''|'\\') | ('\\' .) 
)* '\'' diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/CatalogPlugin.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/CatalogPlugin.java new file mode 100644 index 0000000000000..5d4995a05d233 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/CatalogPlugin.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +/** + * A marker interface to provide a catalog implementation for Spark. + *

+ * Implementations can provide catalog functions by implementing additional interfaces for tables, + * views, and functions. + *

+ * Catalog implementations must implement this marker interface to be loaded by + * {@link Catalogs#load(String, SQLConf)}. The loader will instantiate catalog classes using the + * required public no-arg constructor. After creating an instance, it will be configured by calling + * {@link #initialize(String, CaseInsensitiveStringMap)}. + *

+ * Catalog implementations are registered to a name by adding a configuration option to Spark:
+ * {@code spark.sql.catalog.catalog-name=com.example.YourCatalogClass}. All configuration properties
+ * in the Spark configuration that share the catalog name prefix,
+ * {@code spark.sql.catalog.catalog-name.(key)=(value)}, will be passed in the case-insensitive
+ * string map of options during initialization, with the prefix removed.
+ * The catalog's {@code name} is also passed; in this example it is "catalog-name".
+ */
+@Experimental
+public interface CatalogPlugin {
+  /**
+   * Called to initialize configuration.
+   *

+ * This method is called once, just after the provider is instantiated. + * + * @param name the name used to identify and load this catalog + * @param options a case-insensitive string map of configuration + */ + void initialize(String name, CaseInsensitiveStringMap options); + + /** + * Called to get this catalog's name. + *

+ * This method is only called after {@link #initialize(String, CaseInsensitiveStringMap)} is + * called to pass the catalog's name. + */ + String name(); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Catalogs.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Catalogs.java new file mode 100644 index 0000000000000..851a6a9f6d165 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Catalogs.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2; + +import org.apache.spark.SparkException; +import org.apache.spark.annotation.Private; +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.apache.spark.util.Utils; + +import java.util.HashMap; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static scala.collection.JavaConverters.mapAsJavaMapConverter; + +@Private +public class Catalogs { + private Catalogs() { + } + + /** + * Load and configure a catalog by name. + *
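To make the CatalogPlugin contract described above concrete, here is a minimal Scala sketch of an implementation and the configuration that would register it; the class name, the catalog name "demo", and the "url" option are illustrative and not part of this patch:

    import org.apache.spark.sql.catalog.v2.CatalogPlugin
    import org.apache.spark.sql.util.CaseInsensitiveStringMap

    class DemoCatalog extends CatalogPlugin {
      private var catalogName: String = _

      // Called once by Spark, right after the class is instantiated via its public no-arg constructor.
      override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {
        catalogName = name
        // With spark.sql.catalog.demo.url=... set, the prefix is stripped, so options.get("url") returns it.
      }

      override def name(): String = catalogName
    }

    // Registration in the Spark configuration (hypothetical names):
    //   spark.sql.catalog.demo=com.example.DemoCatalog
    //   spark.sql.catalog.demo.url=jdbc:postgresql://host/db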

+   * This loads, instantiates, and initializes the catalog plugin for each call; it does not cache
+   * or reuse instances.
+   *
+   * @param name a String catalog name
+   * @param conf a SQLConf
+   * @return an initialized CatalogPlugin
+   * @throws CatalogNotFoundException if the plugin class cannot be found
+   * @throws SparkException if the plugin class cannot be instantiated
+   */
+  public static CatalogPlugin load(String name, SQLConf conf)
+      throws CatalogNotFoundException, SparkException {
+    String pluginClassName = conf.getConfString("spark.sql.catalog." + name, null);
+    if (pluginClassName == null) {
+      throw new CatalogNotFoundException(String.format(
+          "Catalog '%s' plugin class not found: spark.sql.catalog.%s is not defined", name, name));
+    }
+
+    ClassLoader loader = Utils.getContextOrSparkClassLoader();
+
+    try {
+      Class<?> pluginClass = loader.loadClass(pluginClassName);
+
+      if (!CatalogPlugin.class.isAssignableFrom(pluginClass)) {
+        throw new SparkException(String.format(
+            "Plugin class for catalog '%s' does not implement CatalogPlugin: %s",
+            name, pluginClassName));
+      }
+
+      CatalogPlugin plugin = CatalogPlugin.class.cast(pluginClass.newInstance());
+
+      plugin.initialize(name, catalogOptions(name, conf));
+
+      return plugin;
+
+    } catch (ClassNotFoundException e) {
+      throw new SparkException(String.format(
+          "Cannot find catalog plugin class for catalog '%s': %s", name, pluginClassName));
+
+    } catch (IllegalAccessException e) {
+      throw new SparkException(String.format(
+          "Failed to call public no-arg constructor for catalog '%s': %s", name, pluginClassName),
+          e);
+
+    } catch (InstantiationException e) {
+      throw new SparkException(String.format(
+          "Failed while instantiating plugin for catalog '%s': %s", name, pluginClassName),
+          e.getCause());
+    }
+  }
+
+  /**
+   * Extracts a named catalog's configuration from a SQLConf.
+   *
+   * @param name a catalog name
+   * @param conf a SQLConf
+   * @return a case insensitive string map of options starting with spark.sql.catalog.(name).
+   */
+  private static CaseInsensitiveStringMap catalogOptions(String name, SQLConf conf) {
+    Map<String, String> allConfs = mapAsJavaMapConverter(conf.getAllConfs()).asJava();
+    Pattern prefix = Pattern.compile("^spark\\.sql\\.catalog\\." + name + "\\.(.+)");
+
+    HashMap<String, String> options = new HashMap<>();
+    for (Map.Entry<String, String> entry : allConfs.entrySet()) {
+      Matcher matcher = prefix.matcher(entry.getKey());
+      if (matcher.matches() && matcher.groupCount() > 0) {
+        options.put(matcher.group(1), entry.getValue());
+      }
+    }
+
+    return new CaseInsensitiveStringMap(options);
+  }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Identifier.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Identifier.java
new file mode 100644
index 0000000000000..3e697c1945bfc
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Identifier.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2; + +import org.apache.spark.annotation.Experimental; + +/** + * Identifies an object in a catalog. + */ +@Experimental +public interface Identifier { + + static Identifier of(String[] namespace, String name) { + return new IdentifierImpl(namespace, name); + } + + /** + * @return the namespace in the catalog + */ + String[] namespace(); + + /** + * @return the object name + */ + String name(); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/IdentifierImpl.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/IdentifierImpl.java new file mode 100644 index 0000000000000..34f3882c9c412 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/IdentifierImpl.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2; + +import com.google.common.base.Preconditions; +import org.apache.spark.annotation.Experimental; + +import java.util.Arrays; +import java.util.Objects; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * An {@link Identifier} implementation. 
+ */ +@Experimental +class IdentifierImpl implements Identifier { + + private String[] namespace; + private String name; + + IdentifierImpl(String[] namespace, String name) { + Preconditions.checkNotNull(namespace, "Identifier namespace cannot be null"); + Preconditions.checkNotNull(name, "Identifier name cannot be null"); + this.namespace = namespace; + this.name = name; + } + + @Override + public String[] namespace() { + return namespace; + } + + @Override + public String name() { + return name; + } + + private String escapeQuote(String part) { + if (part.contains("`")) { + return part.replace("`", "``"); + } else { + return part; + } + } + + @Override + public String toString() { + return Stream.concat(Stream.of(namespace), Stream.of(name)) + .map(part -> '`' + escapeQuote(part) + '`') + .collect(Collectors.joining(".")); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + IdentifierImpl that = (IdentifierImpl) o; + return Arrays.equals(namespace, that.namespace) && name.equals(that.name); + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(namespace), name); + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/TableCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/TableCatalog.java new file mode 100644 index 0000000000000..681629d2d5405 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/TableCatalog.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2; + +import org.apache.spark.sql.catalog.v2.expressions.Transform; +import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException; +import org.apache.spark.sql.sources.v2.Table; +import org.apache.spark.sql.types.StructType; + +import java.util.Map; + +/** + * Catalog methods for working with Tables. + *
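A quick illustration of the Identifier API and the backtick quoting implemented by IdentifierImpl above; the namespace and table name are made up:

    import org.apache.spark.sql.catalog.v2.Identifier

    val ident = Identifier.of(Array("prod", "db"), "events")
    ident.namespace()  // Array("prod", "db")
    ident.name()       // "events"
    ident.toString     // `prod`.`db`.`events`, with any embedded backtick escaped as ``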

+ * TableCatalog implementations may be case sensitive or case insensitive. Spark will pass + * {@link Identifier table identifiers} without modification. Field names passed to + * {@link #alterTable(Identifier, TableChange...)} will be normalized to match the case used in the + * table schema when updating, renaming, or dropping existing columns when catalyst analysis is case + * insensitive. + */ +public interface TableCatalog extends CatalogPlugin { + /** + * List the tables in a namespace from the catalog. + *

+ * If the catalog supports views, this must return identifiers for only tables and not views. + * + * @param namespace a multi-part namespace + * @return an array of Identifiers for tables + * @throws NoSuchNamespaceException If the namespace does not exist (optional). + */ + Identifier[] listTables(String[] namespace) throws NoSuchNamespaceException; + + /** + * Load table metadata by {@link Identifier identifier} from the catalog. + *

+ * If the catalog supports views and contains a view for the identifier and not a table, this + * must throw {@link NoSuchTableException}. + * + * @param ident a table identifier + * @return the table's metadata + * @throws NoSuchTableException If the table doesn't exist or is a view + */ + Table loadTable(Identifier ident) throws NoSuchTableException; + + /** + * Invalidate cached table metadata for an {@link Identifier identifier}. + *

+ * If the table is already loaded or cached, drop cached data. If the table does not exist or is + * not cached, do nothing. Calling this method should not query remote services. + * + * @param ident a table identifier + */ + default void invalidateTable(Identifier ident) { + } + + /** + * Test whether a table exists using an {@link Identifier identifier} from the catalog. + *

+   * If the catalog supports views and contains a view for the identifier and not a table, this
+   * must return false.
+   *
+   * @param ident a table identifier
+   * @return true if the table exists, false otherwise
+   */
+  default boolean tableExists(Identifier ident) {
+    try {
+      return loadTable(ident) != null;
+    } catch (NoSuchTableException e) {
+      return false;
+    }
+  }
+
+  /**
+   * Create a table in the catalog.
+   *
+   * @param ident a table identifier
+   * @param schema the schema of the new table, as a struct type
+   * @param partitions transforms to use for partitioning data in the table
+   * @param properties a string map of table properties
+   * @return metadata for the new table
+   * @throws TableAlreadyExistsException If a table or view already exists for the identifier
+   * @throws UnsupportedOperationException If a requested partition transform is not supported
+   * @throws NoSuchNamespaceException If the identifier namespace does not exist (optional)
+   */
+  Table createTable(
+      Identifier ident,
+      StructType schema,
+      Transform[] partitions,
+      Map<String, String> properties) throws TableAlreadyExistsException, NoSuchNamespaceException;
+
+  /**
+   * Apply a set of {@link TableChange changes} to a table in the catalog.
+   *

+ * Implementations may reject the requested changes. If any change is rejected, none of the + * changes should be applied to the table. + *

+ * If the catalog supports views and contains a view for the identifier and not a table, this + * must throw {@link NoSuchTableException}. + * + * @param ident a table identifier + * @param changes changes to apply to the table + * @return updated metadata for the table + * @throws NoSuchTableException If the table doesn't exist or is a view + * @throws IllegalArgumentException If any change is rejected by the implementation. + */ + Table alterTable( + Identifier ident, + TableChange... changes) throws NoSuchTableException; + + /** + * Drop a table in the catalog. + *

+ * If the catalog supports views and contains a view for the identifier and not a table, this + * must not drop the view and must return false. + * + * @param ident a table identifier + * @return true if a table was deleted, false if no table exists for the identifier + */ + boolean dropTable(Identifier ident); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/TableChange.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/TableChange.java new file mode 100644 index 0000000000000..9b87e676d9b2d --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/TableChange.java @@ -0,0 +1,366 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2; + +import org.apache.spark.sql.types.DataType; + +/** + * TableChange subclasses represent requested changes to a table. These are passed to + * {@link TableCatalog#alterTable}. For example, + *

+ *   import TableChange._
+ *   val catalog = Catalogs.load(name, conf)
+ *   catalog.asTableCatalog.alterTable(ident,
+ *       addColumn("x", IntegerType),
+ *       renameColumn("a", "b"),
+ *       deleteColumn("c")
+ *     )
+ * 
+ */ +public interface TableChange { + + /** + * Create a TableChange for setting a table property. + *
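To connect the alterTable example above with the rest of the new catalog API, a hedged end-to-end sketch in Scala; the catalog name "demo", the identifier, the schema, and the column names are illustrative, and `conf` is assumed to be the active SQLConf:

    import org.apache.spark.sql.catalog.v2.{Catalogs, Identifier, TableCatalog, TableChange}
    import org.apache.spark.sql.catalog.v2.expressions.Expressions
    import org.apache.spark.sql.types.{IntegerType, LongType, StructType, TimestampType}

    val catalog = Catalogs.load("demo", conf).asInstanceOf[TableCatalog]
    val ident = Identifier.of(Array("db"), "events")

    val schema = new StructType().add("id", LongType).add("ts", TimestampType)
    catalog.createTable(ident, schema, Array(Expressions.days("ts")), new java.util.HashMap[String, String]())

    // Changes are applied as a group; if any one is rejected, none should take effect.
    catalog.alterTable(ident,
      TableChange.addColumn(Array("x"), IntegerType),
      TableChange.renameColumn(Array("x"), "y"),
      TableChange.updateColumnComment(Array("y"), "renamed from x"))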

+ * If the property already exists, it will be replaced with the new value. + * + * @param property the property name + * @param value the new property value + * @return a TableChange for the addition + */ + static TableChange setProperty(String property, String value) { + return new SetProperty(property, value); + } + + /** + * Create a TableChange for removing a table property. + *

+ * If the property does not exist, the change will succeed. + * + * @param property the property name + * @return a TableChange for the addition + */ + static TableChange removeProperty(String property) { + return new RemoveProperty(property); + } + + /** + * Create a TableChange for adding an optional column. + *

+ * If the field already exists, the change will result in an {@link IllegalArgumentException}. + * If the new field is nested and its parent does not exist or is not a struct, the change will + * result in an {@link IllegalArgumentException}. + * + * @param fieldNames field names of the new column + * @param dataType the new column's data type + * @return a TableChange for the addition + */ + static TableChange addColumn(String[] fieldNames, DataType dataType) { + return new AddColumn(fieldNames, dataType, true, null); + } + + /** + * Create a TableChange for adding a column. + *

+ * If the field already exists, the change will result in an {@link IllegalArgumentException}. + * If the new field is nested and its parent does not exist or is not a struct, the change will + * result in an {@link IllegalArgumentException}. + * + * @param fieldNames field names of the new column + * @param dataType the new column's data type + * @param isNullable whether the new column can contain null + * @return a TableChange for the addition + */ + static TableChange addColumn(String[] fieldNames, DataType dataType, boolean isNullable) { + return new AddColumn(fieldNames, dataType, isNullable, null); + } + + /** + * Create a TableChange for adding a column. + *

+ * If the field already exists, the change will result in an {@link IllegalArgumentException}. + * If the new field is nested and its parent does not exist or is not a struct, the change will + * result in an {@link IllegalArgumentException}. + * + * @param fieldNames field names of the new column + * @param dataType the new column's data type + * @param isNullable whether the new column can contain null + * @param comment the new field's comment string + * @return a TableChange for the addition + */ + static TableChange addColumn( + String[] fieldNames, + DataType dataType, + boolean isNullable, + String comment) { + return new AddColumn(fieldNames, dataType, isNullable, comment); + } + + /** + * Create a TableChange for renaming a field. + *

+ * The name is used to find the field to rename. The new name will replace the leaf field name. + * For example, renameColumn(["a", "b", "c"], "x") should produce column a.b.x. + *

+ * If the field does not exist, the change will result in an {@link IllegalArgumentException}. + * + * @param fieldNames the current field names + * @param newName the new name + * @return a TableChange for the rename + */ + static TableChange renameColumn(String[] fieldNames, String newName) { + return new RenameColumn(fieldNames, newName); + } + + /** + * Create a TableChange for updating the type of a field that is nullable. + *

+ * The field names are used to find the field to update. + *

+ * If the field does not exist, the change will result in an {@link IllegalArgumentException}. + * + * @param fieldNames field names of the column to update + * @param newDataType the new data type + * @return a TableChange for the update + */ + static TableChange updateColumnType(String[] fieldNames, DataType newDataType) { + return new UpdateColumnType(fieldNames, newDataType, true); + } + + /** + * Create a TableChange for updating the type of a field. + *

+ * The field names are used to find the field to update. + *

+ * If the field does not exist, the change will result in an {@link IllegalArgumentException}. + * + * @param fieldNames field names of the column to update + * @param newDataType the new data type + * @return a TableChange for the update + */ + static TableChange updateColumnType( + String[] fieldNames, + DataType newDataType, + boolean isNullable) { + return new UpdateColumnType(fieldNames, newDataType, isNullable); + } + + /** + * Create a TableChange for updating the comment of a field. + *

+ * The name is used to find the field to update. + *

+ * If the field does not exist, the change will result in an {@link IllegalArgumentException}. + * + * @param fieldNames field names of the column to update + * @param newComment the new comment + * @return a TableChange for the update + */ + static TableChange updateColumnComment(String[] fieldNames, String newComment) { + return new UpdateColumnComment(fieldNames, newComment); + } + + /** + * Create a TableChange for deleting a field. + *

+ * If the field does not exist, the change will result in an {@link IllegalArgumentException}. + * + * @param fieldNames field names of the column to delete + * @return a TableChange for the delete + */ + static TableChange deleteColumn(String[] fieldNames) { + return new DeleteColumn(fieldNames); + } + + /** + * A TableChange to set a table property. + *

+ * If the property already exists, it must be replaced with the new value. + */ + final class SetProperty implements TableChange { + private final String property; + private final String value; + + private SetProperty(String property, String value) { + this.property = property; + this.value = value; + } + + public String property() { + return property; + } + + public String value() { + return value; + } + } + + /** + * A TableChange to remove a table property. + *

+ * If the property does not exist, the change should succeed. + */ + final class RemoveProperty implements TableChange { + private final String property; + + private RemoveProperty(String property) { + this.property = property; + } + + public String property() { + return property; + } + } + + /** + * A TableChange to add a field. + *

+ * If the field already exists, the change must result in an {@link IllegalArgumentException}. + * If the new field is nested and its parent does not exist or is not a struct, the change must + * result in an {@link IllegalArgumentException}. + */ + final class AddColumn implements TableChange { + private final String[] fieldNames; + private final DataType dataType; + private final boolean isNullable; + private final String comment; + + private AddColumn(String[] fieldNames, DataType dataType, boolean isNullable, String comment) { + this.fieldNames = fieldNames; + this.dataType = dataType; + this.isNullable = isNullable; + this.comment = comment; + } + + public String[] fieldNames() { + return fieldNames; + } + + public DataType dataType() { + return dataType; + } + + public boolean isNullable() { + return isNullable; + } + + public String comment() { + return comment; + } + } + + /** + * A TableChange to rename a field. + *

+ * The name is used to find the field to rename. The new name will replace the leaf field name. + * For example, renameColumn("a.b.c", "x") should produce column a.b.x. + *

+ * If the field does not exist, the change must result in an {@link IllegalArgumentException}. + */ + final class RenameColumn implements TableChange { + private final String[] fieldNames; + private final String newName; + + private RenameColumn(String[] fieldNames, String newName) { + this.fieldNames = fieldNames; + this.newName = newName; + } + + public String[] fieldNames() { + return fieldNames; + } + + public String newName() { + return newName; + } + } + + /** + * A TableChange to update the type of a field. + *

+ * The field names are used to find the field to update. + *

+ * If the field does not exist, the change must result in an {@link IllegalArgumentException}. + */ + final class UpdateColumnType implements TableChange { + private final String[] fieldNames; + private final DataType newDataType; + private final boolean isNullable; + + private UpdateColumnType(String[] fieldNames, DataType newDataType, boolean isNullable) { + this.fieldNames = fieldNames; + this.newDataType = newDataType; + this.isNullable = isNullable; + } + + public String[] fieldNames() { + return fieldNames; + } + + public DataType newDataType() { + return newDataType; + } + + public boolean isNullable() { + return isNullable; + } + } + + /** + * A TableChange to update the comment of a field. + *

+ * The field names are used to find the field to update. + *

+ * If the field does not exist, the change must result in an {@link IllegalArgumentException}. + */ + final class UpdateColumnComment implements TableChange { + private final String[] fieldNames; + private final String newComment; + + private UpdateColumnComment(String[] fieldNames, String newComment) { + this.fieldNames = fieldNames; + this.newComment = newComment; + } + + public String[] fieldNames() { + return fieldNames; + } + + public String newComment() { + return newComment; + } + } + + /** + * A TableChange to delete a field. + *

+ * If the field does not exist, the change must result in an {@link IllegalArgumentException}. + */ + final class DeleteColumn implements TableChange { + private final String[] fieldNames; + + private DeleteColumn(String[] fieldNames) { + this.fieldNames = fieldNames; + } + + public String[] fieldNames() { + return fieldNames; + } + } + +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Expression.java similarity index 73% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Expression.java index 43bdcca70cb09..1e2aca9556df4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Expression.java @@ -15,12 +15,17 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2; +package org.apache.spark.sql.catalog.v2.expressions; -import org.apache.spark.annotation.Evolving; +import org.apache.spark.annotation.Experimental; /** - * TODO: remove it when we finish the API refactor for streaming write side. + * Base class of the public logical expression API. */ -@Evolving -public interface DataSourceV2 {} +@Experimental +public interface Expression { + /** + * Format the expression as a human readable SQL-like string. + */ + String describe(); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Expressions.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Expressions.java new file mode 100644 index 0000000000000..d8e49beb0bca5 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Expressions.java @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2.expressions; + +import java.util.Arrays; + +import scala.collection.JavaConverters; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.sql.types.DataType; + +/** + * Helper methods to create logical transforms to pass into Spark. + */ +@Experimental +public class Expressions { + private Expressions() { + } + + /** + * Create a logical transform for applying a named transform. + *

+ * This transform can represent applying any named transform. + * + * @param name the transform name + * @param args expression arguments to the transform + * @return a logical transform + */ + public static Transform apply(String name, Expression... args) { + return LogicalExpressions.apply(name, + JavaConverters.asScalaBufferConverter(Arrays.asList(args)).asScala()); + } + + /** + * Create a named reference expression for a column. + * + * @param name a column name + * @return a named reference for the column + */ + public static NamedReference column(String name) { + return LogicalExpressions.reference(name); + } + + /** + * Create a literal from a value. + *

+   * The JVM type of the value held by a literal must be the type used by Spark's InternalRow API
+   * for the literal's {@link DataType SQL data type}.
+   *
+   * @param value a value
+   * @param <T> the JVM type of the value
+   * @return a literal expression for the value
+   */
+  public static <T> Literal<T> literal(T value) {
+    return LogicalExpressions.literal(value);
+  }
+
+  /**
+   * Create a bucket transform for one or more columns.
+   *

+ * This transform represents a logical mapping from a value to a bucket id in [0, numBuckets) + * based on a hash of the value. + *

+ * The name reported by transforms created with this method is "bucket". + * + * @param numBuckets the number of output buckets + * @param columns input columns for the bucket transform + * @return a logical bucket transform with name "bucket" + */ + public static Transform bucket(int numBuckets, String... columns) { + return LogicalExpressions.bucket(numBuckets, + JavaConverters.asScalaBufferConverter(Arrays.asList(columns)).asScala()); + } + + /** + * Create an identity transform for a column. + *

+ * This transform represents a logical mapping from a value to itself. + *

+ * The name reported by transforms created with this method is "identity". + * + * @param column an input column + * @return a logical identity transform with name "identity" + */ + public static Transform identity(String column) { + return LogicalExpressions.identity(column); + } + + /** + * Create a yearly transform for a timestamp or date column. + *

+ * This transform represents a logical mapping from a timestamp or date to a year, such as 2018. + *

+ * The name reported by transforms created with this method is "years". + * + * @param column an input timestamp or date column + * @return a logical yearly transform with name "years" + */ + public static Transform years(String column) { + return LogicalExpressions.years(column); + } + + /** + * Create a monthly transform for a timestamp or date column. + *

+ * This transform represents a logical mapping from a timestamp or date to a month, such as + * 2018-05. + *

+ * The name reported by transforms created with this method is "months". + * + * @param column an input timestamp or date column + * @return a logical monthly transform with name "months" + */ + public static Transform months(String column) { + return LogicalExpressions.months(column); + } + + /** + * Create a daily transform for a timestamp or date column. + *

+ * This transform represents a logical mapping from a timestamp or date to a date, such as + * 2018-05-13. + *

+ * The name reported by transforms created with this method is "days". + * + * @param column an input timestamp or date column + * @return a logical daily transform with name "days" + */ + public static Transform days(String column) { + return LogicalExpressions.days(column); + } + + /** + * Create an hourly transform for a timestamp column. + *

+ * This transform represents a logical mapping from a timestamp to a date and hour, such as + * 2018-05-13, hour 19. + *

+ * The name reported by transforms created with this method is "hours". + * + * @param column an input timestamp column + * @return a logical hourly transform with name "hours" + */ + public static Transform hours(String column) { + return LogicalExpressions.hours(column); + } + +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsContinuousRead.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Literal.java similarity index 56% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsContinuousRead.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Literal.java index b7fa3f24a238c..e41bcf9000c52 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsContinuousRead.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Literal.java @@ -15,20 +15,28 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2; +package org.apache.spark.sql.catalog.v2.expressions; -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.sources.v2.reader.Scan; -import org.apache.spark.sql.sources.v2.reader.ScanBuilder; +import org.apache.spark.annotation.Experimental; +import org.apache.spark.sql.types.DataType; /** - * An empty mix-in interface for {@link Table}, to indicate this table supports streaming scan with - * continuous mode. + * Represents a constant literal value in the public expression API. *
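A minimal usage sketch of the `Expressions` helpers introduced above, relying only on the factory methods shown in this patch; the column names ("ts", "id", "region") and the `PartitioningExample` class are hypothetical:

```java
import org.apache.spark.sql.catalog.v2.expressions.Expressions;
import org.apache.spark.sql.catalog.v2.expressions.Transform;

public class PartitioningExample {
  public static Transform[] examplePartitioning() {
    // A generic named transform, equivalent to "date(ts)": name() is "date" and
    // references() contains the single "ts" column reference.
    Transform date = Expressions.apply("date", Expressions.column("ts"));

    return new Transform[] {
        Expressions.years("ts"),        // yearly transform, e.g. 2018
        Expressions.bucket(16, "id"),   // hash "id" into buckets [0, 16)
        Expressions.identity("region"), // partition by the raw "region" value
        date
    };
  }
}
```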

- * If a {@link Table} implements this interface, the - * {@link SupportsRead#newScanBuilder(DataSourceOptions)} must return a {@link ScanBuilder} that - * builds {@link Scan} with {@link Scan#toContinuousStream(String)} implemented. - *

+ * The JVM type of the value held by a literal must be the type used by Spark's InternalRow API for + * the literal's {@link DataType SQL data type}. + * + * @param the JVM type of a value held by the literal */ -@Evolving -public interface SupportsContinuousRead extends SupportsRead { } +@Experimental +public interface Literal extends Expression { + /** + * Returns the literal value. + */ + T value(); + + /** + * Returns the SQL data type of the literal. + */ + DataType dataType(); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/NamedReference.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/NamedReference.java new file mode 100644 index 0000000000000..c71ffbe70651f --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/NamedReference.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2.expressions; + +import org.apache.spark.annotation.Experimental; + +/** + * Represents a field or column reference in the public logical expression API. + */ +@Experimental +public interface NamedReference extends Expression { + /** + * Returns the referenced field name as an array of String parts. + *

+ * Each string in the returned array represents a field name. + */ + String[] fieldNames(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsMicroBatchRead.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Transform.java similarity index 54% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsMicroBatchRead.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Transform.java index 9408e323f9da1..c85e0c412f1ab 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsMicroBatchRead.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Transform.java @@ -15,20 +15,30 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2; +package org.apache.spark.sql.catalog.v2.expressions; -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.sources.v2.reader.Scan; -import org.apache.spark.sql.sources.v2.reader.ScanBuilder; +import org.apache.spark.annotation.Experimental; /** - * An empty mix-in interface for {@link Table}, to indicate this table supports streaming scan with - * micro-batch mode. + * Represents a transform function in the public logical expression API. *

- * If a {@link Table} implements this interface, the - * {@link SupportsRead#newScanBuilder(DataSourceOptions)} must return a {@link ScanBuilder} that - * builds {@link Scan} with {@link Scan#toMicroBatchStream(String)} implemented. - *

+ * For example, the transform date(ts) is used to derive a date value from a timestamp column. The + * transform name is "date" and its argument is a reference to the "ts" column. */ -@Evolving -public interface SupportsMicroBatchRead extends SupportsRead { } +@Experimental +public interface Transform extends Expression { + /** + * Returns the transform function name. + */ + String name(); + + /** + * Returns all field references in the transform arguments. + */ + NamedReference[] references(); + + /** + * Returns the arguments passed to the transform function. + */ + Expression[] arguments(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java similarity index 89% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java index c00abd9b685b5..d27fbfdd14617 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java @@ -20,12 +20,12 @@ import org.apache.spark.annotation.Evolving; /** - * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to + * A mix-in interface for {@link TableProvider}. Data sources can implement this interface to * propagate session configs with the specified key-prefix to all data source operations in this * session. */ @Evolving -public interface SessionConfigSupport extends DataSourceV2 { +public interface SessionConfigSupport extends TableProvider { /** * Key prefix of the session configs to propagate, which is usually the data source name. Spark diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java similarity index 76% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java index 5031c71c0fd4d..826fa2f8a0720 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java @@ -19,13 +19,14 @@ import org.apache.spark.sql.sources.v2.reader.Scan; import org.apache.spark.sql.sources.v2.reader.ScanBuilder; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; /** - * An internal base interface of mix-in interfaces for readable {@link Table}. This adds - * {@link #newScanBuilder(DataSourceOptions)} that is used to create a scan for batch, micro-batch, - * or continuous processing. + * A mix-in interface of {@link Table}, to indicate that it's readable. This adds + * {@link #newScanBuilder(CaseInsensitiveStringMap)} that is used to create a scan for batch, + * micro-batch, or continuous processing. */ -interface SupportsRead extends Table { +public interface SupportsRead extends Table { /** * Returns a {@link ScanBuilder} which can be used to build a {@link Scan}. Spark will call this @@ -34,5 +35,5 @@ interface SupportsRead extends Table { * @param options The options for reading, which is an immutable case-insensitive * string-to-string map. 
*/ - ScanBuilder newScanBuilder(DataSourceOptions options); + ScanBuilder newScanBuilder(CaseInsensitiveStringMap options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsWrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/SupportsWrite.java similarity index 77% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsWrite.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/SupportsWrite.java index ecdfe20730254..c52e54569dc0c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsWrite.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/SupportsWrite.java @@ -19,17 +19,18 @@ import org.apache.spark.sql.sources.v2.writer.BatchWrite; import org.apache.spark.sql.sources.v2.writer.WriteBuilder; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; /** - * An internal base interface of mix-in interfaces for writable {@link Table}. This adds - * {@link #newWriteBuilder(DataSourceOptions)} that is used to create a write + * A mix-in interface of {@link Table}, to indicate that it's writable. This adds + * {@link #newWriteBuilder(CaseInsensitiveStringMap)} that is used to create a write * for batch or streaming. */ -interface SupportsWrite extends Table { +public interface SupportsWrite extends Table { /** * Returns a {@link WriteBuilder} which can be used to create {@link BatchWrite}. Spark will call * this method to configure each data source write. */ - WriteBuilder newWriteBuilder(DataSourceOptions options); + WriteBuilder newWriteBuilder(CaseInsensitiveStringMap options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/Table.java similarity index 65% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/Table.java index 08664859b8de2..482d3c22e2306 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/Table.java @@ -18,18 +18,24 @@ package org.apache.spark.sql.sources.v2; import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.catalog.v2.expressions.Transform; import org.apache.spark.sql.types.StructType; +import java.util.Collections; +import java.util.Map; +import java.util.Set; + /** * An interface representing a logical structured data set of a data source. For example, the * implementation can be a directory on the file system, a topic of Kafka, or a table in the * catalog, etc. *

- * This interface can mixin the following interfaces to support different operations: - *

- * <ul>
- *   <li>{@link SupportsBatchRead}: this table can be read in batch queries.</li>
- * </ul>
- *
+ * This interface can mixin the following interfaces to support different operations, like + * {@code SupportsRead}. + *

+ * The default implementation of {@link #partitioning()} returns an empty array of partitions, and + * the default implementation of {@link #properties()} returns an empty map. These should be + * overridden by implementations that support partitioning and table properties. */ @Evolving public interface Table { @@ -45,4 +51,23 @@ public interface Table { * empty schema can be returned here. */ StructType schema(); + + /** + * Returns the physical partitioning of this table. + */ + default Transform[] partitioning() { + return new Transform[0]; + } + + /** + * Returns the string map of table properties. + */ + default Map properties() { + return Collections.emptyMap(); + } + + /** + * Returns the set of capabilities for this table. + */ + Set capabilities(); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/TableCapability.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/TableCapability.java new file mode 100644 index 0000000000000..c44a12b174f4c --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/TableCapability.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.Experimental; + +/** + * Capabilities that can be provided by a {@link Table} implementation. + *

+ * Tables use {@link Table#capabilities()} to return a set of capabilities. Each capability signals + * to Spark that the table supports a feature identified by the capability. For example, returning + * {@code BATCH_READ} allows Spark to read from the table using a batch scan. + */ +@Experimental +public enum TableCapability { + /** + * Signals that the table supports reads in batch execution mode. + */ + BATCH_READ, + + /** + * Signals that the table supports reads in micro-batch streaming execution mode. + */ + MICRO_BATCH_READ, + + /** + * Signals that the table supports reads in continuous streaming execution mode. + */ + CONTINUOUS_READ, + + /** + * Signals that the table supports append writes in batch execution mode. + *

+ * Tables that return this capability must support appending data and may also support additional + * write modes, like {@link #TRUNCATE}, {@link #OVERWRITE_BY_FILTER}, and + * {@link #OVERWRITE_DYNAMIC}. + */ + BATCH_WRITE, + + /** + * Signals that the table supports append writes in streaming execution mode. + *

+ * Tables that return this capability must support appending data and may also support additional + * write modes, like {@link #TRUNCATE}, {@link #OVERWRITE_BY_FILTER}, and + * {@link #OVERWRITE_DYNAMIC}. + */ + STREAMING_WRITE, + + /** + * Signals that the table can be truncated in a write operation. + *

+ * Truncating a table removes all existing rows. + *

+ * See {@code org.apache.spark.sql.sources.v2.writer.SupportsTruncate}. + */ + TRUNCATE, + + /** + * Signals that the table can replace existing data that matches a filter with appended data in + * a write operation. + *

+ * See {@code org.apache.spark.sql.sources.v2.writer.SupportsOverwrite}. + */ + OVERWRITE_BY_FILTER, + + /** + * Signals that the table can dynamically replace existing data partitions with appended data in + * a write operation. + *

+ * See {@code org.apache.spark.sql.sources.v2.writer.SupportsDynamicOverwrite}. + */ + OVERWRITE_DYNAMIC, + + /** + * Signals that the table accepts input of any schema in a write operation. + */ + ACCEPT_ANY_SCHEMA +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java similarity index 79% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java index 855d5efe0c69f..1d37ff042bd33 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java @@ -18,19 +18,22 @@ package org.apache.spark.sql.sources.v2; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.sources.DataSourceRegister; import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; /** * The base interface for v2 data sources which don't have a real catalog. Implementations must * have a public, 0-arg constructor. *

+ * Note that, TableProvider can only apply data operations to existing tables, like read, append, + * delete, and overwrite. It does not support the operations that require metadata changes, like + * create/drop tables. + *

* The major responsibility of this interface is to return a {@link Table} for read/write. *
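A hedged sketch of that contract, reusing the hypothetical `ExampleTable` from the earlier sketch; a real provider would typically inspect the options (for example a path) to decide which table to return:

```java
import org.apache.spark.sql.sources.v2.Table;
import org.apache.spark.sql.sources.v2.TableProvider;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;

// Must have a public, 0-arg constructor; the implicit default constructor suffices here.
public class ExampleTableProvider implements TableProvider {
  @Override
  public Table getTable(CaseInsensitiveStringMap options) {
    // This sketch ignores the options; the schema-specifying overload keeps its
    // default behavior of throwing UnsupportedOperationException.
    return new ExampleTable();
  }
}
```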

*/ @Evolving -// TODO: do not extend `DataSourceV2`, after we finish the API refactor completely. -public interface TableProvider extends DataSourceV2 { +public interface TableProvider { /** * Return a {@link Table} instance to do read/write with user-specified options. @@ -38,7 +41,7 @@ public interface TableProvider extends DataSourceV2 { * @param options the user-specified options that can identify a table, e.g. file path, Kafka * topic name, etc. It's an immutable case-insensitive string-to-string map. */ - Table getTable(DataSourceOptions options); + Table getTable(CaseInsensitiveStringMap options); /** * Return a {@link Table} instance to do read/write with user-specified schema and options. @@ -51,14 +54,8 @@ public interface TableProvider extends DataSourceV2 { * @param schema the user-specified schema. * @throws UnsupportedOperationException */ - default Table getTable(DataSourceOptions options, StructType schema) { - String name; - if (this instanceof DataSourceRegister) { - name = ((DataSourceRegister) this).shortName(); - } else { - name = this.getClass().getName(); - } + default Table getTable(CaseInsensitiveStringMap options, StructType schema) { throw new UnsupportedOperationException( - name + " source does not support user-specified schema"); + this.getClass().getSimpleName() + " source does not support user-specified schema"); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Batch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/Batch.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Batch.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/Batch.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java similarity index 90% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java index 25ab06eee42e0..ac4f38287a24d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java @@ -21,10 +21,8 @@ import 
org.apache.spark.sql.sources.v2.reader.streaming.ContinuousStream; import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchStream; import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.sources.v2.SupportsBatchRead; -import org.apache.spark.sql.sources.v2.SupportsContinuousRead; -import org.apache.spark.sql.sources.v2.SupportsMicroBatchRead; import org.apache.spark.sql.sources.v2.Table; +import org.apache.spark.sql.sources.v2.TableCapability; /** * A logical representation of a data source scan. This interface is used to provide logical @@ -33,8 +31,8 @@ * This logical representation is shared between batch scan, micro-batch streaming scan and * continuous streaming scan. Data sources must implement the corresponding methods in this * interface, to match what the table promises to support. For example, {@link #toBatch()} must be - * implemented, if the {@link Table} that creates this {@link Scan} implements - * {@link SupportsBatchRead}. + * implemented, if the {@link Table} that creates this {@link Scan} returns + * {@link TableCapability#BATCH_READ} support in its {@link Table#capabilities()}. *

*/ @Evolving @@ -62,7 +60,8 @@ default String description() { /** * Returns the physical representation of this scan for batch query. By default this method throws * exception, data sources must overwrite this method to provide an implementation, if the - * {@link Table} that creates this scan implements {@link SupportsBatchRead}. + * {@link Table} that creates this scan returns {@link TableCapability#BATCH_READ} in its + * {@link Table#capabilities()}. * * @throws UnsupportedOperationException */ @@ -73,8 +72,8 @@ default Batch toBatch() { /** * Returns the physical representation of this scan for streaming query with micro-batch mode. By * default this method throws exception, data sources must overwrite this method to provide an - * implementation, if the {@link Table} that creates this scan implements - * {@link SupportsMicroBatchRead}. + * implementation, if the {@link Table} that creates this scan returns + * {@link TableCapability#MICRO_BATCH_READ} support in its {@link Table#capabilities()}. * * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure * recovery. Data streams for the same logical source in the same query @@ -89,8 +88,8 @@ default MicroBatchStream toMicroBatchStream(String checkpointLocation) { /** * Returns the physical representation of this scan for streaming query with continuous mode. By * default this method throws exception, data sources must overwrite this method to provide an - * implementation, if the {@link Table} that creates this scan implements - * {@link SupportsContinuousRead}. + * implementation, if the {@link Table} that creates this scan returns + * {@link TableCapability#CONTINUOUS_READ} support in its {@link Table#capabilities()}. * * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure * recovery. Data streams for the same logical source in the same query diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanBuilder.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanBuilder.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanBuilder.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java similarity index 92% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index 296d3e47e732b..f10fd884daabe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -29,6 +29,9 @@ public interface SupportsPushDownFilters extends ScanBuilder { /** * Pushes down filters, and returns filters that need to be evaluated after scanning. + *
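A hedged sketch of a scan builder honoring the pushdown contract documented here: accepted filters are retained (and must be applied as a conjunction), while everything else is handed back for post-scan evaluation. It assumes the interface's `pushedFilters()` method, which is not shown in this hunk; `ExampleScanBuilder` and the choice to accept only `EqualTo` are illustrative assumptions:

```java
import java.util.ArrayList;
import java.util.List;
import org.apache.spark.sql.sources.EqualTo;
import org.apache.spark.sql.sources.Filter;
import org.apache.spark.sql.sources.v2.reader.Scan;
import org.apache.spark.sql.sources.v2.reader.SupportsPushDownFilters;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

class ExampleScanBuilder implements SupportsPushDownFilters {
  private Filter[] pushed = new Filter[0];

  @Override
  public Filter[] pushFilters(Filter[] filters) {
    List<Filter> accepted = new ArrayList<>();
    List<Filter> postScan = new ArrayList<>();
    for (Filter f : filters) {
      if (f instanceof EqualTo) {
        accepted.add(f);   // evaluated by the source, ANDed with the other pushed filters
      } else {
        postScan.add(f);   // returned so that Spark evaluates it after scanning
      }
    }
    pushed = accepted.toArray(new Filter[0]);
    return postScan.toArray(new Filter[0]);
  }

  @Override
  public Filter[] pushedFilters() {
    return pushed;
  }

  @Override
  public Scan build() {
    // Minimal Scan; a real source would also implement toBatch() and apply `pushed`.
    return new Scan() {
      @Override
      public StructType readSchema() {
        return new StructType().add("id", DataTypes.LongType);
      }
    };
  }
}
```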

+ * Rows should be returned from the data source if and only if all of the filters match. That is, + * filters must be interpreted as ANDed together. */ Filter[] pushFilters(Filter[] filters); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java 
b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousStream.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousStream.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousStream.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousStream.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchStream.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchStream.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchStream.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchStream.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java similarity index 80% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java index a06671383ac5f..1d34fdd1c28ab 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java @@ -25,13 +25,9 @@ * During execution, offsets provided by the data source implementation will be logged and used as * restart checkpoints. Each source should provide an offset implementation which the source can use * to reconstruct a position in the stream up to which data has been seen/processed. - * - * Note: This class currently extends {@link org.apache.spark.sql.execution.streaming.Offset} to - * maintain compatibility with DataSource V1 APIs. This extension will be removed once we - * get rid of V1 completely. */ @Evolving -public abstract class Offset extends org.apache.spark.sql.execution.streaming.Offset { +public abstract class Offset { /** * A JSON-serialized representation of an Offset that is * used for saving offsets to the offset log. 
@@ -49,9 +45,8 @@ public abstract class Offset extends org.apache.spark.sql.execution.streaming.Of */ @Override public boolean equals(Object obj) { - if (obj instanceof org.apache.spark.sql.execution.streaming.Offset) { - return this.json() - .equals(((org.apache.spark.sql.execution.streaming.Offset) obj).json()); + if (obj instanceof Offset) { + return this.json().equals(((Offset) obj).json()); } else { return false; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SparkDataStream.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SparkDataStream.java similarity index 93% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SparkDataStream.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SparkDataStream.java index 30f38ce37c401..2068a84fc6bb1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SparkDataStream.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SparkDataStream.java @@ -18,7 +18,6 @@ package org.apache.spark.sql.sources.v2.reader.streaming; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.execution.streaming.BaseStreamingSource; /** * The base interface representing a readable data stream in a Spark streaming query. It's @@ -28,7 +27,7 @@ * {@link MicroBatchStream} and {@link ContinuousStream}. */ @Evolving -public interface SparkDataStream extends BaseStreamingSource { +public interface SparkDataStream { /** * Returns the initial offset for a streaming query to start reading from. Note that the @@ -50,4 +49,9 @@ public interface SparkDataStream extends BaseStreamingSource { * equal to `end` and will only request offsets greater than `end` in the future. */ void commit(Offset end); + + /** + * Stop this source and free any resources it has allocated. 
+ */ + void stop(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWrite.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWrite.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWrite.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsDynamicOverwrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsDynamicOverwrite.java new file mode 100644 index 0000000000000..8058964b662bd --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsDynamicOverwrite.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer; + +/** + * Write builder trait for tables that support dynamic partition overwrite. + *

+ * A write that dynamically overwrites partitions removes all existing data in each logical + * partition for which the write will commit new data. Any existing logical partition for which the + * write does not contain data will remain unchanged. + *

+ * This is provided to implement SQL compatible with Hive table operations but is not recommended. + * Instead, use the {@link SupportsOverwrite overwrite by filter API} to explicitly replace data. + */ +public interface SupportsDynamicOverwrite extends WriteBuilder { + /** + * Configures a write to dynamically replace partitions with data committed in the write. + * + * @return this write builder for method chaining + */ + WriteBuilder overwriteDynamicPartitions(); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsOverwrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsOverwrite.java new file mode 100644 index 0000000000000..b443b3c3aeb4a --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsOverwrite.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer; + +import org.apache.spark.sql.sources.AlwaysTrue$; +import org.apache.spark.sql.sources.Filter; + +/** + * Write builder trait for tables that support overwrite by filter. + *

+ * Overwriting data by filter will delete any data that matches the filter and replace it with data + * that is committed in the write. + */ +public interface SupportsOverwrite extends WriteBuilder, SupportsTruncate { + /** + * Configures a write to replace data matching the filters with data committed in the write. + *

+ * Rows must be deleted from the data source if and only if all of the filters match. That is, + * filters must be interpreted as ANDed together. + * + * @param filters filters used to match data to overwrite + * @return this write builder for method chaining + */ + WriteBuilder overwrite(Filter[] filters); + + @Override + default WriteBuilder truncate() { + return overwrite(new Filter[] { AlwaysTrue$.MODULE$ }); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsSaveMode.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsTruncate.java similarity index 67% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsSaveMode.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsTruncate.java index c4295f2371877..69c2ba5e01a49 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsSaveMode.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsTruncate.java @@ -17,10 +17,16 @@ package org.apache.spark.sql.sources.v2.writer; -import org.apache.spark.sql.SaveMode; - -// A temporary mixin trait for `WriteBuilder` to support `SaveMode`. Will be removed before -// Spark 3.0 when all the new write operators are finished. See SPARK-26356 for more details. -public interface SupportsSaveMode extends WriteBuilder { - WriteBuilder mode(SaveMode mode); +/** + * Write builder trait for tables that support truncation. + *
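A hedged sketch of a write builder supporting overwrite by filter; because of the default shown above, a `truncate()` call arrives here as `overwrite(new Filter[]{AlwaysTrue})`. The class name and the bare field are illustrative:

```java
import org.apache.spark.sql.sources.Filter;
import org.apache.spark.sql.sources.v2.writer.SupportsOverwrite;
import org.apache.spark.sql.sources.v2.writer.WriteBuilder;

class ExampleWriteBuilder implements SupportsOverwrite {
  private Filter[] overwriteFilters = new Filter[0];

  @Override
  public WriteBuilder overwrite(Filter[] filters) {
    // Rows matching ALL of these filters are deleted and replaced by the committed data.
    this.overwriteFilters = filters;
    return this;
  }

  // truncate() needs no override: the default above rewrites it as overwrite(AlwaysTrue).
  // buildForBatch()/buildForStreaming() keep their throwing defaults in this sketch; a real
  // builder would return a BatchWrite that honors `overwriteFilters`.
}
```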

+ * Truncation removes all data in a table and replaces it with data that is committed in the write. + */ +public interface SupportsTruncate extends WriteBuilder { + /** + * Configures a write to replace all existing data with data committed in the write. + * + * @return this write builder for method chaining + */ + WriteBuilder truncate(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java similarity index 83% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java index e861c72af9e68..bfe41f5e8dfb5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java @@ -18,8 +18,8 @@ package org.apache.spark.sql.sources.v2.writer; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.sources.v2.SupportsBatchWrite; import org.apache.spark.sql.sources.v2.Table; +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite; import org.apache.spark.sql.types.StructType; /** @@ -57,13 +57,16 @@ default WriteBuilder withInputDataSchema(StructType schema) { /** * Returns a {@link BatchWrite} to write data to batch source. By default this method throws * exception, data sources must overwrite this method to provide an implementation, if the - * {@link Table} that creates this scan implements {@link SupportsBatchWrite}. - * - * Note that, the returned {@link BatchWrite} can be null if the implementation supports SaveMode, - * to indicate that no writing is needed. We can clean it up after removing - * {@link SupportsSaveMode}. + * {@link Table} that creates this write returns {@link TableCapability#BATCH_WRITE} support in + * its {@link Table#capabilities()}. */ default BatchWrite buildForBatch() { - throw new UnsupportedOperationException("Batch scans are not supported"); + throw new UnsupportedOperationException(getClass().getName() + + " does not support batch write"); + } + + default StreamingWrite buildForStreaming() { + throw new UnsupportedOperationException(getClass().getName() + + " does not support streaming write"); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java similarity index 94% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java index 6334c8f643098..23e8580c404d4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java @@ -20,12 +20,12 @@ import java.io.Serializable; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport; +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite; /** * A commit message returned by {@link DataWriter#commit()} and will be sent back to the driver side * as the input parameter of {@link BatchWrite#commit(WriterCommitMessage[])} or - * {@link StreamingWriteSupport#commit(long, WriterCommitMessage[])}. 
+ * {@link StreamingWrite#commit(long, WriterCommitMessage[])}. * * This is an empty interface, data sources should define their own message class and use it when * generating messages at executor side and handling the messages at driver side. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java similarity index 96% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java index 7d3d21cb2b637..af2f03c9d4192 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java @@ -26,7 +26,7 @@ /** * A factory of {@link DataWriter} returned by - * {@link StreamingWriteSupport#createStreamingWriterFactory()}, which is responsible for creating + * {@link StreamingWrite#createStreamingWriterFactory()}, which is responsible for creating * and initializing the actual data writer at executor side. * * Note that, the writer factory will be serialized and sent to executors, then the data writer diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWrite.java similarity index 73% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWrite.java index 84cfbf2dda483..5617f1cdc0efc 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWrite.java @@ -22,13 +22,26 @@ import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage; /** - * An interface that defines how to write the data to data source for streaming processing. + * An interface that defines how to write the data to data source in streaming queries. * - * Streaming queries are divided into intervals of data called epochs, with a monotonically - * increasing numeric ID. This writer handles commits and aborts for each successive epoch. + * The writing procedure is: + * 1. Create a writer factory by {@link #createStreamingWriterFactory()}, serialize and send it to + * all the partitions of the input data(RDD). + * 2. For each epoch in each partition, create the data writer, and write the data of the epoch in + * the partition with this writer. If all the data are written successfully, call + * {@link DataWriter#commit()}. If exception happens during the writing, call + * {@link DataWriter#abort()}. + * 3. If writers in all partitions of one epoch are successfully committed, call + * {@link #commit(long, WriterCommitMessage[])}. If some writers are aborted, or the job failed + * with an unknown reason, call {@link #abort(long, WriterCommitMessage[])}. + * + * While Spark will retry failed writing tasks, Spark won't retry failed writing jobs. Users should + * do it manually in their Spark applications if they want to retry. 
+ * + * Please refer to the documentation of commit/abort methods for detailed specifications. */ @Evolving -public interface StreamingWriteSupport { +public interface StreamingWrite { /** * Creates a writer factory which will be serialized and sent to executors. diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/util/CaseInsensitiveStringMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/util/CaseInsensitiveStringMap.java new file mode 100644 index 0000000000000..da41346d7ce71 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/util/CaseInsensitiveStringMap.java @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util; + +import org.apache.spark.annotation.Experimental; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; +import java.util.Set; + +/** + * Case-insensitive map of string keys to string values. + *

+ * This is used to pass options to v2 implementations to ensure consistent case insensitivity. + *
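A hedged usage sketch of this map; the option keys and values are illustrative:

```java
import java.util.HashMap;
import java.util.Map;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;

class OptionsExample {
  static void example() {
    Map<String, String> raw = new HashMap<>();
    raw.put("Path", "/tmp/data");
    raw.put("mergeSchema", "true");

    CaseInsensitiveStringMap options = new CaseInsensitiveStringMap(raw);

    String path = options.get("path");                        // "/tmp/data", despite the "Path" key
    boolean merge = options.getBoolean("MERGESCHEMA", false);  // true
    int batchSize = options.getInt("batchSize", 1024);         // no such key, falls back to 1024

    // The original, case-preserving keys remain available (read-only).
    Map<String, String> original = options.asCaseSensitiveMap();
    System.out.println(path + " " + merge + " " + batchSize + " " + original.containsKey("Path"));
  }
}
```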

+ * Methods that return keys in this map, like {@link #entrySet()} and {@link #keySet()}, return + * keys converted to lower case. This map doesn't allow null key. + */ +@Experimental +public class CaseInsensitiveStringMap implements Map { + private final Logger logger = LoggerFactory.getLogger(CaseInsensitiveStringMap.class); + + private String unsupportedOperationMsg = "CaseInsensitiveStringMap is read-only."; + + public static CaseInsensitiveStringMap empty() { + return new CaseInsensitiveStringMap(new HashMap<>(0)); + } + + private final Map original; + + private final Map delegate; + + public CaseInsensitiveStringMap(Map originalMap) { + original = new HashMap<>(originalMap); + delegate = new HashMap<>(originalMap.size()); + for (Map.Entry entry : originalMap.entrySet()) { + String key = toLowerCase(entry.getKey()); + if (delegate.containsKey(key)) { + logger.warn("Converting duplicated key " + entry.getKey() + + " into CaseInsensitiveStringMap."); + } + delegate.put(key, entry.getValue()); + } + } + + @Override + public int size() { + return delegate.size(); + } + + @Override + public boolean isEmpty() { + return delegate.isEmpty(); + } + + private String toLowerCase(Object key) { + return key.toString().toLowerCase(Locale.ROOT); + } + + @Override + public boolean containsKey(Object key) { + return delegate.containsKey(toLowerCase(key)); + } + + @Override + public boolean containsValue(Object value) { + return delegate.containsValue(value); + } + + @Override + public String get(Object key) { + return delegate.get(toLowerCase(key)); + } + + @Override + public String put(String key, String value) { + throw new UnsupportedOperationException(unsupportedOperationMsg); + } + + @Override + public String remove(Object key) { + throw new UnsupportedOperationException(unsupportedOperationMsg); + } + + @Override + public void putAll(Map m) { + throw new UnsupportedOperationException(unsupportedOperationMsg); + } + + @Override + public void clear() { + throw new UnsupportedOperationException(unsupportedOperationMsg); + } + + @Override + public Set keySet() { + return delegate.keySet(); + } + + @Override + public Collection values() { + return delegate.values(); + } + + @Override + public Set> entrySet() { + return delegate.entrySet(); + } + + /** + * Returns the boolean value to which the specified key is mapped, + * or defaultValue if there is no mapping for the key. The key match is case-insensitive. + */ + public boolean getBoolean(String key, boolean defaultValue) { + String value = get(key); + // We can't use `Boolean.parseBoolean` here, as it returns false for invalid strings. + if (value == null) { + return defaultValue; + } else if (value.equalsIgnoreCase("true")) { + return true; + } else if (value.equalsIgnoreCase("false")) { + return false; + } else { + throw new IllegalArgumentException(value + " is not a boolean string."); + } + } + + /** + * Returns the integer value to which the specified key is mapped, + * or defaultValue if there is no mapping for the key. The key match is case-insensitive. + */ + public int getInt(String key, int defaultValue) { + String value = get(key); + return value == null ? defaultValue : Integer.parseInt(value); + } + + /** + * Returns the long value to which the specified key is mapped, + * or defaultValue if there is no mapping for the key. The key match is case-insensitive. + */ + public long getLong(String key, long defaultValue) { + String value = get(key); + return value == null ? 
defaultValue : Long.parseLong(value); + } + + /** + * Returns the double value to which the specified key is mapped, + * or defaultValue if there is no mapping for the key. The key match is case-insensitive. + */ + public double getDouble(String key, double defaultValue) { + String value = get(key); + return value == null ? defaultValue : Double.parseDouble(value); + } + + /** + * Returns the original case-sensitive map. + */ + public Map asCaseSensitiveMap() { + return Collections.unmodifiableMap(original); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java similarity index 99% rename from sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index 906e9bc26ef53..07d17ee14ce23 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -23,7 +23,7 @@ import org.apache.arrow.vector.holders.NullableVarCharHolder; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.execution.arrow.ArrowUtils; +import org.apache.spark.sql.util.ArrowUtils; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.UTF8String; diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java new file mode 100644 index 0000000000000..9f917ea11d72a --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java @@ -0,0 +1,280 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.vectorized; + +import java.util.*; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * This class wraps multiple ColumnVectors as a row-wise table. It provides a row view of this + * batch so that Spark can access the data row by row. Instance of it is meant to be reused during + * the entire data loading process. + */ +@Evolving +public final class ColumnarBatch { + private int numRows; + private final ColumnVector[] columns; + + // Staging row returned from `getRow`. + private final ColumnarBatchRow row; + + /** + * Called to close all the columns in this batch. It is not valid to access the data after + * calling this. This must be called at the end to clean up memory allocations. + */ + public void close() { + for (ColumnVector c: columns) { + c.close(); + } + } + + /** + * Returns an iterator over the rows in this batch. + */ + public Iterator rowIterator() { + final int maxRows = numRows; + final ColumnarBatchRow row = new ColumnarBatchRow(columns); + return new Iterator() { + int rowId = 0; + + @Override + public boolean hasNext() { + return rowId < maxRows; + } + + @Override + public InternalRow next() { + if (rowId >= maxRows) { + throw new NoSuchElementException(); + } + row.rowId = rowId++; + return row; + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + + /** + * Sets the number of rows in this batch. + */ + public void setNumRows(int numRows) { + this.numRows = numRows; + } + + /** + * Returns the number of columns that make up this batch. + */ + public int numCols() { return columns.length; } + + /** + * Returns the number of rows for read, including filtered rows. + */ + public int numRows() { return numRows; } + + /** + * Returns the column at `ordinal`. + */ + public ColumnVector column(int ordinal) { return columns[ordinal]; } + + /** + * Returns the row in this batch at `rowId`. Returned row is reused across calls. + */ + public InternalRow getRow(int rowId) { + assert(rowId >= 0 && rowId < numRows); + row.rowId = rowId; + return row; + } + + public ColumnarBatch(ColumnVector[] columns) { + this.columns = columns; + this.row = new ColumnarBatchRow(columns); + } +} + +/** + * An internal class, which wraps an array of {@link ColumnVector} and provides a row view. 
+ */ +class ColumnarBatchRow extends InternalRow { + public int rowId; + private final ColumnVector[] columns; + + ColumnarBatchRow(ColumnVector[] columns) { + this.columns = columns; + } + + @Override + public int numFields() { return columns.length; } + + @Override + public InternalRow copy() { + GenericInternalRow row = new GenericInternalRow(columns.length); + for (int i = 0; i < numFields(); i++) { + if (isNullAt(i)) { + row.setNullAt(i); + } else { + DataType dt = columns[i].dataType(); + if (dt instanceof BooleanType) { + row.setBoolean(i, getBoolean(i)); + } else if (dt instanceof ByteType) { + row.setByte(i, getByte(i)); + } else if (dt instanceof ShortType) { + row.setShort(i, getShort(i)); + } else if (dt instanceof IntegerType) { + row.setInt(i, getInt(i)); + } else if (dt instanceof LongType) { + row.setLong(i, getLong(i)); + } else if (dt instanceof FloatType) { + row.setFloat(i, getFloat(i)); + } else if (dt instanceof DoubleType) { + row.setDouble(i, getDouble(i)); + } else if (dt instanceof StringType) { + row.update(i, getUTF8String(i).copy()); + } else if (dt instanceof BinaryType) { + row.update(i, getBinary(i)); + } else if (dt instanceof DecimalType) { + DecimalType t = (DecimalType)dt; + row.setDecimal(i, getDecimal(i, t.precision(), t.scale()), t.precision()); + } else if (dt instanceof DateType) { + row.setInt(i, getInt(i)); + } else if (dt instanceof TimestampType) { + row.setLong(i, getLong(i)); + } else { + throw new RuntimeException("Not implemented. " + dt); + } + } + } + return row; + } + + @Override + public boolean anyNull() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isNullAt(int ordinal) { return columns[ordinal].isNullAt(rowId); } + + @Override + public boolean getBoolean(int ordinal) { return columns[ordinal].getBoolean(rowId); } + + @Override + public byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); } + + @Override + public short getShort(int ordinal) { return columns[ordinal].getShort(rowId); } + + @Override + public int getInt(int ordinal) { return columns[ordinal].getInt(rowId); } + + @Override + public long getLong(int ordinal) { return columns[ordinal].getLong(rowId); } + + @Override + public float getFloat(int ordinal) { return columns[ordinal].getFloat(rowId); } + + @Override + public double getDouble(int ordinal) { return columns[ordinal].getDouble(rowId); } + + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + return columns[ordinal].getDecimal(rowId, precision, scale); + } + + @Override + public UTF8String getUTF8String(int ordinal) { + return columns[ordinal].getUTF8String(rowId); + } + + @Override + public byte[] getBinary(int ordinal) { + return columns[ordinal].getBinary(rowId); + } + + @Override + public CalendarInterval getInterval(int ordinal) { + return columns[ordinal].getInterval(rowId); + } + + @Override + public ColumnarRow getStruct(int ordinal, int numFields) { + return columns[ordinal].getStruct(rowId); + } + + @Override + public ColumnarArray getArray(int ordinal) { + return columns[ordinal].getArray(rowId); + } + + @Override + public ColumnarMap getMap(int ordinal) { + return columns[ordinal].getMap(rowId); + } + + @Override + public Object get(int ordinal, DataType dataType) { + if (dataType instanceof BooleanType) { + return getBoolean(ordinal); + } else if (dataType instanceof ByteType) { + return getByte(ordinal); + } else if (dataType instanceof ShortType) { + return getShort(ordinal); + } else if (dataType instanceof IntegerType) { + 
return getInt(ordinal); + } else if (dataType instanceof LongType) { + return getLong(ordinal); + } else if (dataType instanceof FloatType) { + return getFloat(ordinal); + } else if (dataType instanceof DoubleType) { + return getDouble(ordinal); + } else if (dataType instanceof StringType) { + return getUTF8String(ordinal); + } else if (dataType instanceof BinaryType) { + return getBinary(ordinal); + } else if (dataType instanceof DecimalType) { + DecimalType t = (DecimalType) dataType; + return getDecimal(ordinal, t.precision(), t.scale()); + } else if (dataType instanceof DateType) { + return getInt(ordinal); + } else if (dataType instanceof TimestampType) { + return getLong(ordinal); + } else if (dataType instanceof ArrayType) { + return getArray(ordinal); + } else if (dataType instanceof StructType) { + return getStruct(ordinal, ((StructType)dataType).fields().length); + } else if (dataType instanceof MapType) { + return getMap(ordinal); + } else { + throw new UnsupportedOperationException("Datatype not supported " + dataType); + } + } + + @Override + public void update(int ordinal, Object value) { throw new UnsupportedOperationException(); } + + @Override + public void setNullAt(int ordinal) { throw new UnsupportedOperationException(); } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSink.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/CatalogNotFoundException.scala similarity index 72% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSink.java rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/CatalogNotFoundException.scala index ac96c2765368f..86de1c9285b73 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSink.java +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/CatalogNotFoundException.scala @@ -15,13 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.streaming; +package org.apache.spark.sql.catalog.v2 -/** - * The shared interface between V1 and V2 streaming sinks. - * - * This is a temporary interface for compatibility during migration. It should not be implemented - * directly, and will be removed in future versions. 
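To make the row-view contract of ColumnarBatch concrete, here is a sketch that builds a small batch and drains it through rowIterator(); it leans on OnHeapColumnVector from sql/core, which is outside this diff, so treat the wiring as illustrative:

```scala
import scala.collection.JavaConverters._

import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector  // not part of this diff
import org.apache.spark.sql.types.{IntegerType, StringType}
import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}
import org.apache.spark.unsafe.types.UTF8String

val ids = new OnHeapColumnVector(2, IntegerType)
val names = new OnHeapColumnVector(2, StringType)
ids.putInt(0, 1); ids.putInt(1, 2)
names.putByteArray(0, UTF8String.fromString("a").getBytes)
names.putByteArray(1, UTF8String.fromString("b").getBytes)

val batch = new ColumnarBatch(Array[ColumnVector](ids, names))
batch.setNumRows(2)

// rowIterator() reuses a single ColumnarBatchRow, so copy() any row you keep around.
val retained = batch.rowIterator().asScala.map(_.copy()).toList
assert(retained.map(r => (r.getInt(0), r.getUTF8String(1).toString)) == List((1, "a"), (2, "b")))

batch.close()  // closes every wrapped column vector
```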
- */ -public interface BaseStreamingSink { +import org.apache.spark.SparkException +import org.apache.spark.annotation.Experimental + +@Experimental +class CatalogNotFoundException(message: String, cause: Throwable) + extends SparkException(message, cause) { + + def this(message: String) = this(message, null) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/CatalogV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/CatalogV2Implicits.scala new file mode 100644 index 0000000000000..f512cd5e23c6b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/CatalogV2Implicits.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2 + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalog.v2.expressions.{BucketTransform, IdentityTransform, LogicalExpressions, Transform} +import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.types.StructType + +/** + * Conversion helpers for working with v2 [[CatalogPlugin]]. 
+ */ +object CatalogV2Implicits { + implicit class PartitionTypeHelper(partitionType: StructType) { + def asTransforms: Array[Transform] = partitionType.names.map(LogicalExpressions.identity) + } + + implicit class BucketSpecHelper(spec: BucketSpec) { + def asTransform: BucketTransform = { + if (spec.sortColumnNames.nonEmpty) { + throw new AnalysisException( + s"Cannot convert bucketing with sort columns to a transform: $spec") + } + + LogicalExpressions.bucket(spec.numBuckets, spec.bucketColumnNames: _*) + } + } + + implicit class TransformHelper(transforms: Seq[Transform]) { + def asPartitionColumns: Seq[String] = { + val (idTransforms, nonIdTransforms) = transforms.partition(_.isInstanceOf[IdentityTransform]) + + if (nonIdTransforms.nonEmpty) { + throw new AnalysisException("Transforms cannot be converted to partition columns: " + + nonIdTransforms.map(_.describe).mkString(", ")) + } + + idTransforms.map(_.asInstanceOf[IdentityTransform]).map(_.reference).map { ref => + val parts = ref.fieldNames + if (parts.size > 1) { + throw new AnalysisException(s"Cannot partition by nested column: $ref") + } else { + parts(0) + } + } + } + } + + implicit class CatalogHelper(plugin: CatalogPlugin) { + def asTableCatalog: TableCatalog = plugin match { + case tableCatalog: TableCatalog => + tableCatalog + case _ => + throw new AnalysisException(s"Cannot use catalog ${plugin.name}: not a TableCatalog") + } + } + + implicit class NamespaceHelper(namespace: Array[String]) { + def quoted: String = namespace.map(quote).mkString(".") + } + + implicit class IdentifierHelper(ident: Identifier) { + def quoted: String = { + if (ident.namespace.nonEmpty) { + ident.namespace.map(quote).mkString(".") + "." + quote(ident.name) + } else { + quote(ident.name) + } + } + } + + implicit class MultipartIdentifierHelper(namespace: Seq[String]) { + def quoted: String = namespace.map(quote).mkString(".") + } + + private def quote(part: String): String = { + if (part.contains(".") || part.contains("`")) { + s"`${part.replace("`", "``")}`" + } else { + part + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/LookupCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/LookupCatalog.scala new file mode 100644 index 0000000000000..5464a7496d23d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/LookupCatalog.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2 + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.TableIdentifier + +/** + * A trait to encapsulate catalog lookup function and helpful extractors. 
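A small sketch of the implicit helpers defined in CatalogV2Implicits above, showing the quoting rules and the BucketSpec conversion:

```scala
import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._
import org.apache.spark.sql.catalyst.catalog.BucketSpec

// Parts are backquoted only when they contain '.' or '`'.
assert(Seq("db", "my.table").quoted == "db.`my.table`")

// A Hive-style bucket spec maps to a bucket transform; specs with sort columns
// are rejected with an AnalysisException because they have no transform equivalent.
val transform = BucketSpec(8, Seq("id"), Nil).asTransform
assert(transform.describe == "bucket(8, id)")
```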
+ */ +@Experimental +trait LookupCatalog { + + protected def lookupCatalog(name: String): CatalogPlugin + + type CatalogObjectIdentifier = (Option[CatalogPlugin], Identifier) + + /** + * Extract catalog plugin and identifier from a multi-part identifier. + */ + object CatalogObjectIdentifier { + def unapply(parts: Seq[String]): Some[CatalogObjectIdentifier] = parts match { + case Seq(name) => + Some((None, Identifier.of(Array.empty, name))) + case Seq(catalogName, tail @ _*) => + try { + Some((Some(lookupCatalog(catalogName)), Identifier.of(tail.init.toArray, tail.last))) + } catch { + case _: CatalogNotFoundException => + Some((None, Identifier.of(parts.init.toArray, parts.last))) + } + } + } + + /** + * Extract legacy table identifier from a multi-part identifier. + * + * For legacy support only. Please use [[CatalogObjectIdentifier]] instead on DSv2 code paths. + */ + object AsTableIdentifier { + def unapply(parts: Seq[String]): Option[TableIdentifier] = parts match { + case CatalogObjectIdentifier(None, ident) => + ident.namespace match { + case Array() => + Some(TableIdentifier(ident.name)) + case Array(database) => + Some(TableIdentifier(ident.name, Some(database))) + case _ => + None + } + case _ => + None + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/expressions/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/expressions/expressions.scala new file mode 100644 index 0000000000000..2d4d6e7c6d5ee --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/expressions/expressions.scala @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2.expressions + +import org.apache.spark.sql.catalyst +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{DataType, IntegerType, StringType} + +/** + * Helper methods for working with the logical expressions API. + * + * Factory methods can be used when referencing the logical expression nodes is ambiguous because + * logical and internal expressions are used. + */ +private[sql] object LogicalExpressions { + // a generic parser that is only used for parsing multi-part field names. + // because this is only used for field names, the SQL conf passed in does not matter. 
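The extractors above are easiest to read with a concrete resolver. The sketch below uses a resolver that knows no catalogs (mirroring the Analyzer default added in this patch), so every multi-part name falls through to a plain identifier:

```scala
import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, LookupCatalog}
import org.apache.spark.sql.catalyst.TableIdentifier

// Illustrative resolver with no registered catalogs.
object NoCatalogs extends LookupCatalog {
  override protected def lookupCatalog(name: String): CatalogPlugin =
    throw new CatalogNotFoundException(s"Catalog not found: $name")
}
import NoCatalogs._

Seq("db", "tbl") match {
  case CatalogObjectIdentifier(maybeCatalog, ident) =>
    assert(maybeCatalog.isEmpty)                      // "db" is not a catalog here
    assert(ident.namespace.toSeq == Seq("db") && ident.name == "tbl")
}

// The legacy extractor covers one- and two-part names only.
assert(AsTableIdentifier.unapply(Seq("db", "tbl")).contains(TableIdentifier("tbl", Some("db"))))
assert(AsTableIdentifier.unapply(Seq("cat", "db", "tbl")).isEmpty)
```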
+ private lazy val parser = new CatalystSqlParser(SQLConf.get) + + def literal[T](value: T): LiteralValue[T] = { + val internalLit = catalyst.expressions.Literal(value) + literal(value, internalLit.dataType) + } + + def literal[T](value: T, dataType: DataType): LiteralValue[T] = LiteralValue(value, dataType) + + def reference(name: String): NamedReference = + FieldReference(parser.parseMultipartIdentifier(name)) + + def apply(name: String, arguments: Expression*): Transform = ApplyTransform(name, arguments) + + def bucket(numBuckets: Int, columns: String*): BucketTransform = + BucketTransform(literal(numBuckets, IntegerType), columns.map(reference)) + + def identity(column: String): IdentityTransform = IdentityTransform(reference(column)) + + def years(column: String): YearsTransform = YearsTransform(reference(column)) + + def months(column: String): MonthsTransform = MonthsTransform(reference(column)) + + def days(column: String): DaysTransform = DaysTransform(reference(column)) + + def hours(column: String): HoursTransform = HoursTransform(reference(column)) +} + +/** + * Base class for simple transforms of a single column. + */ +private[sql] abstract class SingleColumnTransform(ref: NamedReference) extends Transform { + + def reference: NamedReference = ref + + override def references: Array[NamedReference] = Array(ref) + + override def arguments: Array[Expression] = Array(ref) + + override def describe: String = name + "(" + reference.describe + ")" + + override def toString: String = describe +} + +private[sql] final case class BucketTransform( + numBuckets: Literal[Int], + columns: Seq[NamedReference]) extends Transform { + + override val name: String = "bucket" + + override def references: Array[NamedReference] = { + arguments + .filter(_.isInstanceOf[NamedReference]) + .map(_.asInstanceOf[NamedReference]) + } + + override def arguments: Array[Expression] = numBuckets +: columns.toArray + + override def describe: String = s"bucket(${arguments.map(_.describe).mkString(", ")})" + + override def toString: String = describe +} + +private[sql] final case class ApplyTransform( + name: String, + args: Seq[Expression]) extends Transform { + + override def arguments: Array[Expression] = args.toArray + + override def references: Array[NamedReference] = { + arguments + .filter(_.isInstanceOf[NamedReference]) + .map(_.asInstanceOf[NamedReference]) + } + + override def describe: String = s"$name(${arguments.map(_.describe).mkString(", ")})" + + override def toString: String = describe +} + +private[sql] final case class IdentityTransform( + ref: NamedReference) extends SingleColumnTransform(ref) { + override val name: String = "identity" + override def describe: String = ref.describe +} + +private[sql] final case class YearsTransform( + ref: NamedReference) extends SingleColumnTransform(ref) { + override val name: String = "years" +} + +private[sql] final case class MonthsTransform( + ref: NamedReference) extends SingleColumnTransform(ref) { + override val name: String = "months" +} + +private[sql] final case class DaysTransform( + ref: NamedReference) extends SingleColumnTransform(ref) { + override val name: String = "days" +} + +private[sql] final case class HoursTransform( + ref: NamedReference) extends SingleColumnTransform(ref) { + override val name: String = "hours" +} + +private[sql] final case class LiteralValue[T](value: T, dataType: DataType) extends Literal[T] { + override def describe: String = { + if (dataType.isInstanceOf[StringType]) { + s"'$value'" + } else { + s"$value" + } + } 
+ override def toString: String = describe +} + +private[sql] final case class FieldReference(parts: Seq[String]) extends NamedReference { + import org.apache.spark.sql.catalog.v2.CatalogV2Implicits.MultipartIdentifierHelper + override def fieldNames: Array[String] = parts.toArray + override def describe: String = parts.quoted + override def toString: String = describe +} + +private[sql] object FieldReference { + def apply(column: String): NamedReference = { + LogicalExpressions.reference(column) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala index 6d587abd8fd4d..f5e9a146bf359 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._ +import org.apache.spark.sql.catalog.v2.Identifier import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec /** @@ -25,13 +27,26 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec * as an [[org.apache.spark.sql.AnalysisException]] with the correct position information. */ class DatabaseAlreadyExistsException(db: String) - extends AnalysisException(s"Database '$db' already exists") + extends NamespaceAlreadyExistsException(s"Database '$db' already exists") -class TableAlreadyExistsException(db: String, table: String) - extends AnalysisException(s"Table or view '$table' already exists in database '$db'") +class NamespaceAlreadyExistsException(message: String) extends AnalysisException(message) { + def this(namespace: Array[String]) = { + this(s"Namespace '${namespace.quoted}' already exists") + } +} + +class TableAlreadyExistsException(message: String) extends AnalysisException(message) { + def this(db: String, table: String) = { + this(s"Table or view '$table' already exists in database '$db'") + } + + def this(tableIdent: Identifier) = { + this(s"Table ${tableIdent.quoted} already exists") + } +} class TempTableAlreadyExistsException(table: String) - extends AnalysisException(s"Temporary view '$table' already exists") + extends TableAlreadyExistsException(s"Temporary view '$table' already exists") class PartitionAlreadyExistsException(db: String, table: String, spec: TablePartitionSpec) extends AnalysisException( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a84bb7653c527..e0c0ad6efb483 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer import scala.util.Random import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, LookupCatalog} import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.OuterScopes @@ -96,12 +97,15 @@ class Analyzer( catalog: SessionCatalog, conf: SQLConf, maxIterations: Int) - extends RuleExecutor[LogicalPlan] with CheckAnalysis { + extends 
RuleExecutor[LogicalPlan] with CheckAnalysis with LookupCatalog { def this(catalog: SessionCatalog, conf: SQLConf) = { this(catalog, conf, conf.optimizerMaxIterations) } + override protected def lookupCatalog(name: String): CatalogPlugin = + throw new CatalogNotFoundException("No catalog lookup function") + def executeAndCheck(plan: LogicalPlan, tracker: QueryPlanningTracker): LogicalPlan = { AnalysisHelper.markInAnalyzer { val analyzed = executeAndTrack(plan, tracker) @@ -978,6 +982,11 @@ class Analyzer( case a @ Aggregate(groupingExprs, aggExprs, appendColumns: AppendColumns) => a.mapExpressions(resolveExpressionTopDown(_, appendColumns)) + case o: OverwriteByExpression if !o.outputResolved => + // do not resolve expression attributes until the query attributes are resolved against the + // table by ResolveOutputRelation. that rule will alias the attributes to the table's names. + o + case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString(SQLConf.get.maxToStringFields)}") q.mapExpressions(resolveExpressionTopDown(_, q)) @@ -2237,7 +2246,7 @@ class Analyzer( object ResolveOutputRelation extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case append @ AppendData(table, query, isByName) - if table.resolved && query.resolved && !append.resolved => + if table.resolved && query.resolved && !append.outputResolved => val projection = resolveOutputColumns(table.name, table.output, query, isByName) if (projection != query) { @@ -2245,6 +2254,26 @@ class Analyzer( } else { append } + + case overwrite @ OverwriteByExpression(table, _, query, isByName) + if table.resolved && query.resolved && !overwrite.outputResolved => + val projection = resolveOutputColumns(table.name, table.output, query, isByName) + + if (projection != query) { + overwrite.copy(query = projection) + } else { + overwrite + } + + case overwrite @ OverwritePartitionsDynamic(table, query, isByName) + if table.resolved && query.resolved && !overwrite.outputResolved => + val projection = resolveOutputColumns(table.name, table.output, query, isByName) + + if (projection != query) { + overwrite.copy(query = projection) + } else { + overwrite + } } def resolveOutputColumns( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 18c40b370cb5f..fcb2eec609c28 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -33,6 +33,8 @@ import org.apache.spark.sql.types._ */ trait CheckAnalysis extends PredicateHelper { + import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._ + /** * Override to provide additional checks for correct analysis. * These rules will be evaluated after our built-in check rules. 
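Circling back to the transform expressions introduced in expressions.scala above, a short sketch of the factory methods and their describe output (LogicalExpressions is private[sql], so this assumes code living under org.apache.spark.sql, for example a catalyst test):

```scala
import org.apache.spark.sql.catalog.v2.expressions.LogicalExpressions

// Identity and time-based partition transforms over single column references.
assert(LogicalExpressions.identity("region").describe == "region")
assert(LogicalExpressions.days("event_time").describe == "days(event_time)")

// Bucketing takes a literal bucket count plus one or more columns.
assert(LogicalExpressions.bucket(16, "id", "name").describe == "bucket(16, id, name)")

// Column names are parsed as multi-part identifiers, so dots and backquotes work.
assert(LogicalExpressions.reference("point.x").fieldNames.toSeq == Seq("point", "x"))
assert(LogicalExpressions.reference("`a.b`").fieldNames.toSeq == Seq("a.b"))
```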
@@ -296,6 +298,21 @@ trait CheckAnalysis extends PredicateHelper { } } + case CreateTableAsSelect(_, _, partitioning, query, _, _, _) => + val references = partitioning.flatMap(_.references).toSet + val badReferences = references.map(_.fieldNames).flatMap { column => + query.schema.findNestedField(column) match { + case Some(_) => + None + case _ => + Some(s"${column.quoted} is missing or is in a map or array") + } + } + + if (badReferences.nonEmpty) { + failAnalysis(s"Invalid partitioning: ${badReferences.mkString(", ")}") + } + case _ => // Fallbacks to the following checks } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NamedRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NamedRelation.scala index ad201f947b671..56b8d84441c95 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NamedRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NamedRelation.scala @@ -21,4 +21,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan trait NamedRelation extends LogicalPlan { def name: String + + // When false, the schema of input data must match the schema of this relation, during write. + def skipSchemaResolution: Boolean = false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala index 8bf6f69f3b17a..7ac8ae61ed537 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._ +import org.apache.spark.sql.catalog.v2.Identifier import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec @@ -25,10 +27,24 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec * Thrown by a catalog when an item cannot be found. The analyzer will rethrow the exception * as an [[org.apache.spark.sql.AnalysisException]] with the correct position information. 
*/ -class NoSuchDatabaseException(val db: String) extends AnalysisException(s"Database '$db' not found") +class NoSuchDatabaseException( + val db: String) extends NoSuchNamespaceException(s"Database '$db' not found") -class NoSuchTableException(db: String, table: String) - extends AnalysisException(s"Table or view '$table' not found in database '$db'") +class NoSuchNamespaceException(message: String) extends AnalysisException(message) { + def this(namespace: Array[String]) = { + this(s"Namespace '${namespace.quoted}' not found") + } +} + +class NoSuchTableException(message: String) extends AnalysisException(message) { + def this(db: String, table: String) = { + this(s"Table or view '$table' not found in database '$db'") + } + + def this(tableIdent: Identifier) = { + this(s"Table ${tableIdent.quoted} not found") + } +} class NoSuchPartitionException( db: String, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index a27c6d3c3671c..81ec2a1d9c904 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -29,14 +29,18 @@ import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode} import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalog.v2 +import org.apache.spark.sql.catalog.v2.expressions.{ApplyTransform, BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform} import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.catalog.CatalogStorageFormat +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last} import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.sql.{CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement} +import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -86,6 +90,11 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging visitFunctionIdentifier(ctx.functionIdentifier) } + override def visitSingleMultipartIdentifier( + ctx: SingleMultipartIdentifierContext): Seq[String] = withOrigin(ctx) { + visitMultipartIdentifier(ctx.multipartIdentifier) + } + override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) { visitSparkDataType(ctx.dataType) } @@ -117,6 +126,10 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging } } + override def visitQueryToDesc(ctx: QueryToDescContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.queryTerm).optionalMap(ctx.queryOrganization)(withQueryResultClauses) + } + /** * Create a named logical plan. 
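A quick sketch of the new Identifier-based constructors on the exceptions above, next to the legacy (db, table) form they keep supporting:

```scala
import org.apache.spark.sql.catalog.v2.Identifier
import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, TableAlreadyExistsException}

val ident = Identifier.of(Array("prod", "db"), "events")

// The v2 constructors render the quoted multi-part name.
assert(new NoSuchTableException(ident).getMessage == "Table prod.db.events not found")
assert(new TableAlreadyExistsException(ident).getMessage == "Table prod.db.events already exists")

// The v1 (db, table) constructors keep their original message format.
assert(new NoSuchTableException("db", "tbl").getMessage ==
  "Table or view 'tbl' not found in database 'db'")
```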
* @@ -953,6 +966,14 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging FunctionIdentifier(ctx.function.getText, Option(ctx.db).map(_.getText)) } + /** + * Create a multi-part identifier. + */ + override def visitMultipartIdentifier( + ctx: MultipartIdentifierContext): Seq[String] = withOrigin(ctx) { + ctx.parts.asScala.map(_.getText) + } + /* ******************************************************************************************** * Expression parsing * ******************************************************************************************** */ @@ -1851,4 +1872,301 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val structField = StructField(identifier.getText, typedVisit(dataType), nullable = true) if (STRING == null) structField else structField.withComment(string(STRING)) } + + /** + * Create location string. + */ + override def visitLocationSpec(ctx: LocationSpecContext): String = withOrigin(ctx) { + string(ctx.STRING) + } + + /** + * Create a [[BucketSpec]]. + */ + override def visitBucketSpec(ctx: BucketSpecContext): BucketSpec = withOrigin(ctx) { + BucketSpec( + ctx.INTEGER_VALUE.getText.toInt, + visitIdentifierList(ctx.identifierList), + Option(ctx.orderedIdentifierList) + .toSeq + .flatMap(_.orderedIdentifier.asScala) + .map { orderedIdCtx => + Option(orderedIdCtx.ordering).map(_.getText).foreach { dir => + if (dir.toLowerCase(Locale.ROOT) != "asc") { + operationNotAllowed(s"Column ordering must be ASC, was '$dir'", ctx) + } + } + + orderedIdCtx.identifier.getText + }) + } + + /** + * Convert a table property list into a key-value map. + * This should be called through [[visitPropertyKeyValues]] or [[visitPropertyKeys]]. + */ + override def visitTablePropertyList( + ctx: TablePropertyListContext): Map[String, String] = withOrigin(ctx) { + val properties = ctx.tableProperty.asScala.map { property => + val key = visitTablePropertyKey(property.key) + val value = visitTablePropertyValue(property.value) + key -> value + } + // Check for duplicate property names. + checkDuplicateKeys(properties, ctx) + properties.toMap + } + + /** + * Parse a key-value map from a [[TablePropertyListContext]], assuming all values are specified. + */ + def visitPropertyKeyValues(ctx: TablePropertyListContext): Map[String, String] = { + val props = visitTablePropertyList(ctx) + val badKeys = props.collect { case (key, null) => key } + if (badKeys.nonEmpty) { + operationNotAllowed( + s"Values must be specified for key(s): ${badKeys.mkString("[", ",", "]")}", ctx) + } + props + } + + /** + * Parse a list of keys from a [[TablePropertyListContext]], assuming no values are specified. + */ + def visitPropertyKeys(ctx: TablePropertyListContext): Seq[String] = { + val props = visitTablePropertyList(ctx) + val badKeys = props.filter { case (_, v) => v != null }.keys + if (badKeys.nonEmpty) { + operationNotAllowed( + s"Values should not be specified for key(s): ${badKeys.mkString("[", ",", "]")}", ctx) + } + props.keys.toSeq + } + + /** + * A table property key can either be String or a collection of dot separated elements. This + * function extracts the property key based on whether its a string literal or a table property + * identifier. + */ + override def visitTablePropertyKey(key: TablePropertyKeyContext): String = { + if (key.STRING != null) { + string(key.STRING) + } else { + key.getText + } + } + + /** + * A table property value can be String, Integer, Boolean or Decimal. 
This function extracts + * the property value based on whether its a string, integer, boolean or decimal literal. + */ + override def visitTablePropertyValue(value: TablePropertyValueContext): String = { + if (value == null) { + null + } else if (value.STRING != null) { + string(value.STRING) + } else if (value.booleanValue != null) { + value.getText.toLowerCase(Locale.ROOT) + } else { + value.getText + } + } + + /** + * Type to keep track of a table header: (identifier, isTemporary, ifNotExists, isExternal). + */ + type TableHeader = (Seq[String], Boolean, Boolean, Boolean) + + /** + * Validate a create table statement and return the [[TableIdentifier]]. + */ + override def visitCreateTableHeader( + ctx: CreateTableHeaderContext): TableHeader = withOrigin(ctx) { + val temporary = ctx.TEMPORARY != null + val ifNotExists = ctx.EXISTS != null + if (temporary && ifNotExists) { + operationNotAllowed("CREATE TEMPORARY TABLE ... IF NOT EXISTS", ctx) + } + val multipartIdentifier = ctx.multipartIdentifier.parts.asScala.map(_.getText) + (multipartIdentifier, temporary, ifNotExists, ctx.EXTERNAL != null) + } + + /** + * Parse a list of transforms. + */ + override def visitTransformList(ctx: TransformListContext): Seq[Transform] = withOrigin(ctx) { + def getFieldReference( + ctx: ApplyTransformContext, + arg: v2.expressions.Expression): FieldReference = { + lazy val name: String = ctx.identifier.getText + arg match { + case ref: FieldReference => + ref + case nonRef => + throw new ParseException( + s"Expected a column reference for transform $name: ${nonRef.describe}", ctx) + } + } + + def getSingleFieldReference( + ctx: ApplyTransformContext, + arguments: Seq[v2.expressions.Expression]): FieldReference = { + lazy val name: String = ctx.identifier.getText + if (arguments.size > 1) { + throw new ParseException(s"Too many arguments for transform $name", ctx) + } else if (arguments.isEmpty) { + throw new ParseException(s"Not enough arguments for transform $name", ctx) + } else { + getFieldReference(ctx, arguments.head) + } + } + + ctx.transforms.asScala.map { + case identityCtx: IdentityTransformContext => + IdentityTransform(FieldReference( + identityCtx.qualifiedName.identifier.asScala.map(_.getText))) + + case applyCtx: ApplyTransformContext => + val arguments = applyCtx.argument.asScala.map(visitTransformArgument) + + applyCtx.identifier.getText match { + case "bucket" => + val numBuckets: Int = arguments.head match { + case LiteralValue(shortValue, ShortType) => + shortValue.asInstanceOf[Short].toInt + case LiteralValue(intValue, IntegerType) => + intValue.asInstanceOf[Int] + case LiteralValue(longValue, LongType) => + longValue.asInstanceOf[Long].toInt + case lit => + throw new ParseException(s"Invalid number of buckets: ${lit.describe}", applyCtx) + } + + val fields = arguments.tail.map(arg => getFieldReference(applyCtx, arg)) + + BucketTransform(LiteralValue(numBuckets, IntegerType), fields) + + case "years" => + YearsTransform(getSingleFieldReference(applyCtx, arguments)) + + case "months" => + MonthsTransform(getSingleFieldReference(applyCtx, arguments)) + + case "days" => + DaysTransform(getSingleFieldReference(applyCtx, arguments)) + + case "hours" => + HoursTransform(getSingleFieldReference(applyCtx, arguments)) + + case name => + ApplyTransform(name, arguments) + } + } + } + + /** + * Parse an argument to a transform. An argument may be a field reference (qualified name) or + * a value literal. 
+ */ + override def visitTransformArgument(ctx: TransformArgumentContext): v2.expressions.Expression = { + withOrigin(ctx) { + val reference = Option(ctx.qualifiedName) + .map(nameCtx => FieldReference(nameCtx.identifier.asScala.map(_.getText))) + val literal = Option(ctx.constant) + .map(typedVisit[Literal]) + .map(lit => LiteralValue(lit.value, lit.dataType)) + reference.orElse(literal) + .getOrElse(throw new ParseException(s"Invalid transform argument", ctx)) + } + } + + /** + * Create a table, returning a [[CreateTableStatement]] logical plan. + * + * Expected format: + * {{{ + * CREATE [TEMPORARY] TABLE [IF NOT EXISTS] [db_name.]table_name + * USING table_provider + * create_table_clauses + * [[AS] select_statement]; + * + * create_table_clauses (order insensitive): + * [OPTIONS table_property_list] + * [PARTITIONED BY (col_name, transform(col_name), transform(constant, col_name), ...)] + * [CLUSTERED BY (col_name, col_name, ...) + * [SORTED BY (col_name [ASC|DESC], ...)] + * INTO num_buckets BUCKETS + * ] + * [LOCATION path] + * [COMMENT table_comment] + * [TBLPROPERTIES (property_name=property_value, ...)] + * }}} + */ + override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) { + val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) + if (external) { + operationNotAllowed("CREATE EXTERNAL TABLE ... USING", ctx) + } + + checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx) + checkDuplicateClauses(ctx.OPTIONS, "OPTIONS", ctx) + checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx) + checkDuplicateClauses(ctx.COMMENT, "COMMENT", ctx) + checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx) + checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx) + + val schema = Option(ctx.colTypeList()).map(createSchema) + val partitioning: Seq[Transform] = + Option(ctx.partitioning).map(visitTransformList).getOrElse(Nil) + val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec) + val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty) + val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty) + + val provider = ctx.tableProvider.qualifiedName.getText + val location = ctx.locationSpec.asScala.headOption.map(visitLocationSpec) + val comment = Option(ctx.comment).map(string) + + Option(ctx.query).map(plan) match { + case Some(_) if temp => + operationNotAllowed("CREATE TEMPORARY TABLE ... USING ... AS query", ctx) + + case Some(_) if schema.isDefined => + operationNotAllowed( + "Schema may not be specified in a Create Table As Select (CTAS) statement", + ctx) + + case Some(query) => + CreateTableAsSelectStatement( + table, query, partitioning, bucketSpec, properties, provider, options, location, comment, + ifNotExists = ifNotExists) + + case None if temp => + // CREATE TEMPORARY TABLE ... USING ... is not supported by the catalyst parser. + // Use CREATE TEMPORARY VIEW ... USING ... instead. + operationNotAllowed("CREATE TEMPORARY TABLE IF NOT EXISTS", ctx) + + case _ => + CreateTableStatement(table, schema.getOrElse(new StructType), partitioning, bucketSpec, + properties, provider, options, location, comment, ifNotExists = ifNotExists) + } + } + + /** + * Create a [[DropTableStatement]] command. 
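Putting visitCreateTable and visitDropTable together, a sketch of what the catalyst parser now produces; this assumes the accompanying SqlBase.g4 changes for multi-part names and transform lists, which are not shown in this hunk:

```scala
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.logical.sql.{CreateTableStatement, DropTableStatement}
import org.apache.spark.sql.internal.SQLConf

val parser = new CatalystSqlParser(SQLConf.get)

parser.parsePlan(
  """CREATE TABLE testcat.db.events (id BIGINT, ts TIMESTAMP)
    |USING parquet
    |PARTITIONED BY (bucket(16, id), days(ts))
    |COMMENT 'event log'""".stripMargin) match {
  case create: CreateTableStatement =>
    assert(create.tableName == Seq("testcat", "db", "events"))
    assert(create.partitioning.map(_.describe) == Seq("bucket(16, id)", "days(ts)"))
    assert(create.provider == "parquet" && create.comment.contains("event log"))
}

parser.parsePlan("DROP TABLE IF EXISTS testcat.db.events PURGE") match {
  case drop: DropTableStatement =>
    assert(drop.tableName == Seq("testcat", "db", "events"))
    assert(drop.ifExists && drop.purge)
}
```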
+ */ + override def visitDropTable(ctx: DropTableContext): LogicalPlan = withOrigin(ctx) { + DropTableStatement( + visitMultipartIdentifier(ctx.multipartIdentifier()), + ctx.EXISTS != null, + ctx.PURGE != null) + } + + /** + * Create a [[DropViewStatement]] command. + */ + override def visitDropView(ctx: DropViewContext): AnyRef = withOrigin(ctx) { + DropViewStatement( + visitMultipartIdentifier(ctx.multipartIdentifier()), + ctx.EXISTS != null) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index 2128a10d0b1bc..31917ab9a5579 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -57,6 +57,13 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { } } + /** Creates a multi-part identifier for a given SQL string */ + override def parseMultipartIdentifier(sqlText: String): Seq[String] = { + parse(sqlText) { parser => + astBuilder.visitSingleMultipartIdentifier(parser.singleMultipartIdentifier()) + } + } + /** * Creates StructType for a given SQL string, which is a comma separated list of field * definitions which will preserve the correct Hive metadata. @@ -85,6 +92,7 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { lexer.removeErrorListeners() lexer.addErrorListener(ParseErrorListener) lexer.legacy_setops_precedence_enbled = SQLConf.get.setOpsPrecedenceEnforced + lexer.ansi = SQLConf.get.ansiParserEnabled val tokenStream = new CommonTokenStream(lexer) val parser = new SqlBaseParser(tokenStream) @@ -92,6 +100,7 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { parser.removeErrorListeners() parser.addErrorListener(ParseErrorListener) parser.legacy_setops_precedence_enbled = SQLConf.get.setOpsPrecedenceEnforced + parser.ansi = SQLConf.get.ansiParserEnabled try { try { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala index 75240d2196222..77e357ad073da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala @@ -52,6 +52,12 @@ trait ParserInterface { @throws[ParseException]("Text cannot be parsed to a FunctionIdentifier") def parseFunctionIdentifier(sqlText: String): FunctionIdentifier + /** + * Parse a string to a multi-part identifier. + */ + @throws[ParseException]("Text cannot be parsed to a multi-part identifier") + def parseMultipartIdentifier(sqlText: String): Seq[String] + /** * Parse a string to a [[StructType]]. The passed SQL string should be a comma separated list * of field definitions which will preserve the correct Hive metadata. 
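And a minimal sketch of the new parseMultipartIdentifier entry point itself:

```scala
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.internal.SQLConf

val parser = new CatalystSqlParser(SQLConf.get)

// Dots separate parts; backquotes keep a dotted name as a single part.
assert(parser.parseMultipartIdentifier("testcat.db.tbl") == Seq("testcat", "db", "tbl"))
assert(parser.parseMultipartIdentifier("db.`my.table`") == Seq("db", "my.table"))
```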
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 639d68f4ecd76..256d3261055e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalog.v2.{Identifier, TableCatalog} +import org.apache.spark.sql.catalog.v2.expressions.Transform import org.apache.spark.sql.catalyst.AliasIdentifier import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation} import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable} @@ -365,37 +367,132 @@ case class Join( } /** - * Append data to an existing table. + * Base trait for DataSourceV2 write commands */ -case class AppendData( - table: NamedRelation, - query: LogicalPlan, - isByName: Boolean) extends LogicalPlan { +trait V2WriteCommand extends Command { + def table: NamedRelation + def query: LogicalPlan + override def children: Seq[LogicalPlan] = Seq(query) - override def output: Seq[Attribute] = Seq.empty - override lazy val resolved: Boolean = { - table.resolved && query.resolved && query.output.size == table.output.size && + override lazy val resolved: Boolean = outputResolved + + def outputResolved: Boolean = { + // If the table doesn't require schema match, we don't need to resolve the output columns. + table.skipSchemaResolution || { + table.resolved && query.resolved && query.output.size == table.output.size && query.output.zip(table.output).forall { case (inAttr, outAttr) => // names and types must match, nullability must be compatible inAttr.name == outAttr.name && - DataType.equalsIgnoreCompatibleNullability(outAttr.dataType, inAttr.dataType) && - (outAttr.nullable || !inAttr.nullable) + DataType.equalsIgnoreCompatibleNullability(outAttr.dataType, inAttr.dataType) && + (outAttr.nullable || !inAttr.nullable) } + } } } +/** + * Create a new table with a v2 catalog. + */ +case class CreateV2Table( + catalog: TableCatalog, + tableName: Identifier, + tableSchema: StructType, + partitioning: Seq[Transform], + properties: Map[String, String], + ignoreIfExists: Boolean) extends Command + +/** + * Create a new table from a select query with a v2 catalog. + */ +case class CreateTableAsSelect( + catalog: TableCatalog, + tableName: Identifier, + partitioning: Seq[Transform], + query: LogicalPlan, + properties: Map[String, String], + writeOptions: Map[String, String], + ignoreIfExists: Boolean) extends Command { + + override def children: Seq[LogicalPlan] = Seq(query) + + override lazy val resolved: Boolean = { + // the table schema is created from the query schema, so the only resolution needed is to check + // that the columns referenced by the table's partitioning exist in the query schema + val references = partitioning.flatMap(_.references).toSet + references.map(_.fieldNames).forall(query.schema.findNestedField(_).isDefined) + } +} + +/** + * Append data to an existing table. 
+ */ +case class AppendData( + table: NamedRelation, + query: LogicalPlan, + isByName: Boolean) extends V2WriteCommand + object AppendData { def byName(table: NamedRelation, df: LogicalPlan): AppendData = { - new AppendData(table, df, true) + new AppendData(table, df, isByName = true) } def byPosition(table: NamedRelation, query: LogicalPlan): AppendData = { - new AppendData(table, query, false) + new AppendData(table, query, isByName = false) + } +} + +/** + * Overwrite data matching a filter in an existing table. + */ +case class OverwriteByExpression( + table: NamedRelation, + deleteExpr: Expression, + query: LogicalPlan, + isByName: Boolean) extends V2WriteCommand { + override lazy val resolved: Boolean = outputResolved && deleteExpr.resolved +} + +object OverwriteByExpression { + def byName( + table: NamedRelation, df: LogicalPlan, deleteExpr: Expression): OverwriteByExpression = { + OverwriteByExpression(table, deleteExpr, df, isByName = true) + } + + def byPosition( + table: NamedRelation, query: LogicalPlan, deleteExpr: Expression): OverwriteByExpression = { + OverwriteByExpression(table, deleteExpr, query, isByName = false) } } +/** + * Dynamically overwrite partitions in an existing table. + */ +case class OverwritePartitionsDynamic( + table: NamedRelation, + query: LogicalPlan, + isByName: Boolean) extends V2WriteCommand + +object OverwritePartitionsDynamic { + def byName(table: NamedRelation, df: LogicalPlan): OverwritePartitionsDynamic = { + OverwritePartitionsDynamic(table, df, isByName = true) + } + + def byPosition(table: NamedRelation, query: LogicalPlan): OverwritePartitionsDynamic = { + OverwritePartitionsDynamic(table, query, isByName = false) + } +} + +/** + * Drop a table. + */ +case class DropTable( + catalog: TableCatalog, + ident: Identifier, + ifExists: Boolean) extends Command + + /** * Insert some data into a table. Note that this plan is unresolved and has to be replaced by the * concrete implementations during analysis. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/CreateTableStatement.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/CreateTableStatement.scala new file mode 100644 index 0000000000000..7a26e01cde830 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/CreateTableStatement.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
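To see how outputResolved gates ResolveOutputRelation for the v2 write plans above, here is a sketch modeled on how analysis suites exercise them; TestRelation is a made-up stand-in for a resolved v2 table relation:

```scala
import org.apache.spark.sql.catalyst.analysis.NamedRelation
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LeafNode, LocalRelation}
import org.apache.spark.sql.types.{LongType, StringType}

// Hypothetical relation used only to exercise the output-resolution check.
case class TestRelation(output: Seq[AttributeReference]) extends LeafNode with NamedRelation {
  override def name: String = "test_table"
}

val table = TestRelation(Seq(
  AttributeReference("id", LongType, nullable = false)(),
  AttributeReference("data", StringType)()))

// Matching names, types and nullability: the write is output-resolved.
val ok = LocalRelation(
  AttributeReference("id", LongType, nullable = false)(),
  AttributeReference("data", StringType)())
assert(AppendData.byName(table, ok).outputResolved)

// A nullable query column cannot be written into a non-nullable table column,
// so ResolveOutputRelation would have to reject or adjust this query.
val nullableId = LocalRelation(
  AttributeReference("id", LongType, nullable = true)(),
  AttributeReference("data", StringType)())
assert(!AppendData.byName(table, nullableId).outputResolved)
```

The same outputResolved check backs OverwriteByExpression and OverwritePartitionsDynamic, which only add the delete expression and dynamic-partition semantics on top of it.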
+ */ + +package org.apache.spark.sql.catalyst.plans.logical.sql + +import org.apache.spark.sql.catalog.v2.expressions.Transform +import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.types.StructType + +/** + * A CREATE TABLE command, as parsed from SQL. + * + * This is a metadata-only command and is not used to write data to the created table. + */ +case class CreateTableStatement( + tableName: Seq[String], + tableSchema: StructType, + partitioning: Seq[Transform], + bucketSpec: Option[BucketSpec], + properties: Map[String, String], + provider: String, + options: Map[String, String], + location: Option[String], + comment: Option[String], + ifNotExists: Boolean) extends ParsedStatement { + + override def output: Seq[Attribute] = Seq.empty + + override def children: Seq[LogicalPlan] = Seq.empty +} + +/** + * A CREATE TABLE AS SELECT command, as parsed from SQL. + */ +case class CreateTableAsSelectStatement( + tableName: Seq[String], + asSelect: LogicalPlan, + partitioning: Seq[Transform], + bucketSpec: Option[BucketSpec], + properties: Map[String, String], + provider: String, + options: Map[String, String], + location: Option[String], + comment: Option[String], + ifNotExists: Boolean) extends ParsedStatement { + + override def output: Seq[Attribute] = Seq.empty + + override def children: Seq[LogicalPlan] = Seq(asSelect) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/DropTableStatement.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/DropTableStatement.scala new file mode 100644 index 0000000000000..d41e8a5010257 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/DropTableStatement.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical.sql + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +/** + * A DROP TABLE statement, as parsed from SQL. 
+ */ +case class DropTableStatement( + tableName: Seq[String], + ifExists: Boolean, + purge: Boolean) extends ParsedStatement { + + override def output: Seq[Attribute] = Seq.empty + + override def children: Seq[LogicalPlan] = Seq.empty +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/DropViewStatement.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/DropViewStatement.scala new file mode 100644 index 0000000000000..523158788e834 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/DropViewStatement.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical.sql + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +/** + * A DROP VIEW statement, as parsed from SQL. + */ +case class DropViewStatement( + viewName: Seq[String], + ifExists: Boolean) extends ParsedStatement { + + override def output: Seq[Attribute] = Seq.empty + + override def children: Seq[LogicalPlan] = Seq.empty +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/ParsedStatement.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/ParsedStatement.scala new file mode 100644 index 0000000000000..510f2a1ba1e6d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/ParsedStatement.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical.sql + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +/** + * A logical plan node that contains exactly what was parsed from SQL. + * + * This is used to hold information parsed from SQL when there are multiple implementations of a + * query or command. For example, CREATE TABLE may be implemented by different nodes for v1 and v2. 
+ * Instead of parsing directly to a v1 CreateTable that keeps metadata in CatalogTable, and then
+ * converting that v1 metadata to the v2 equivalent, the sql [[CreateTableStatement]] plan is
+ * produced by the parser and converted once into both implementations.
+ *
+ * Parsed logical plans are not resolved because they must be converted to concrete logical plans.
+ *
+ * Parsed logical plans are located in Catalyst so that as much SQL parsing logic as possible can
+ * be kept in a [[org.apache.spark.sql.catalyst.parser.AbstractSqlParser]].
+ */
+private[sql] abstract class ParsedStatement extends LogicalPlan {
+  // Redact properties and options when parsed nodes are used by generic methods like toString
+  override def productIterator: Iterator[Any] = super.productIterator.map {
+    case mapArg: Map[_, _] => conf.redactOptions(mapArg.asInstanceOf[Map[String, String]])
+    case other => other
+  }
+
+  final override lazy val resolved = false
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index f590c63f80b21..a85cad35ac6fc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.util
 import java.sql.{Date, Timestamp}
 import java.text.{DateFormat, SimpleDateFormat}
+import java.time.ZoneId
 import java.util.{Calendar, Locale, TimeZone}
 import java.util.concurrent.ConcurrentHashMap
 import java.util.function.{Function => JFunction}
@@ -123,6 +124,8 @@ object DateTimeUtils {
     override def apply(timeZoneId: String): TimeZone = TimeZone.getTimeZone(timeZoneId)
   }
+  def getZoneId(timeZoneId: String): ZoneId = ZoneId.of(timeZoneId, ZoneId.SHORT_IDS)
+
   def getTimeZone(timeZoneId: String): TimeZone = {
     computedTimeZones.computeIfAbsent(timeZoneId, computeTimeZone)
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 448338f61346f..cbc57066163c1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -314,6 +314,12 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
+  val ANSI_SQL_PARSER =
+    buildConf("spark.sql.parser.ansi.enabled")
+      .doc("When true, tries to conform to ANSI SQL syntax.")
+      .booleanConf
+      .createWithDefault(false)
+
   val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals")
     .internal()
     .doc("When true, string literals (including regex patterns) remain escaped in our SQL " +
@@ -918,6 +924,12 @@ object SQLConf {
     .stringConf
     .createOptional
+  val FORCE_DELETE_TEMP_CHECKPOINT_LOCATION =
+    buildConf("spark.sql.streaming.forceDeleteTempCheckpointLocation")
+      .doc("When true, enable force deletion of temporary checkpoint locations.")
+      .booleanConf
+      .createWithDefault(false)
+
   val MIN_BATCHES_TO_RETAIN = buildConf("spark.sql.streaming.minBatchesToRetain")
     .internal()
     .doc("The minimum number of batches that must be retained and made recoverable.")
@@ -1117,6 +1129,14 @@ object SQLConf {
     .internal()
     .stringConf
+  val STREAMING_CHECKPOINT_ESCAPED_PATH_CHECK_ENABLED =
+    buildConf("spark.sql.streaming.checkpoint.escapedPathCheck.enabled")
+      .doc("Whether to detect that a streaming query may pick up an incorrect checkpoint path due " +
+        "to 
SPARK-26824.") + .internal() + .booleanConf + .createWithDefault(true) + val PARALLEL_FILE_LISTING_IN_STATS_COMPUTATION = buildConf("spark.sql.statistics.parallelFileListingInStatsComputation.enabled") .internal() @@ -1427,6 +1447,13 @@ object SQLConf { .booleanConf .createWithDefault(true) + val CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE = + buildConf("spark.sql.streaming.continuous.epochBacklogQueueSize") + .doc("The max number of entries to be stored in queue to wait for late epochs. " + + "If this parameter is exceeded by the size of the queue, stream will stop with an error.") + .intConf + .createWithDefault(10000) + val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE = buildConf("spark.sql.streaming.continuous.executorQueueSize") .internal() @@ -1457,7 +1484,7 @@ object SQLConf { " register class names for which data source V2 write paths are disabled. Writes from these" + " sources will fall back to the V1 sources.") .stringConf - .createWithDefault("") + .createWithDefault("orc") val DISABLED_V2_STREAMING_WRITERS = buildConf("spark.sql.streaming.disabledV2Writers") .doc("A comma-separated list of fully qualified data source register class names for which" + @@ -1673,6 +1700,11 @@ object SQLConf { "a SparkConf entry.") .booleanConf .createWithDefault(true) + + val DEFAULT_V2_CATALOG = buildConf("spark.sql.default.catalog") + .doc("Name of the default v2 catalog, used when a catalog is not identified in queries") + .stringConf + .createOptional } /** @@ -1848,6 +1880,8 @@ class SQLConf extends Serializable with Logging { def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED) + def ansiParserEnabled: Boolean = getConf(ANSI_SQL_PARSER) + def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS) def fileCompressionFactor: Double = getConf(FILE_COMRESSION_FACTOR) @@ -2059,14 +2093,17 @@ class SQLConf extends Serializable with Logging { def literalPickMinimumPrecision: Boolean = getConf(LITERAL_PICK_MINIMUM_PRECISION) + def continuousStreamingEpochBacklogQueueSize: Int = + getConf(CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE) + def continuousStreamingExecutorQueueSize: Int = getConf(CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE) def continuousStreamingExecutorPollIntervalMs: Long = getConf(CONTINUOUS_STREAMING_EXECUTOR_POLL_INTERVAL_MS) - def userV1SourceReaderList: String = getConf(USE_V1_SOURCE_READER_LIST) + def useV1SourceReaderList: String = getConf(USE_V1_SOURCE_READER_LIST) - def userV1SourceWriterList: String = getConf(USE_V1_SOURCE_WRITER_LIST) + def useV1SourceWriterList: String = getConf(USE_V1_SOURCE_WRITER_LIST) def disabledV2StreamingWriters: String = getConf(DISABLED_V2_STREAMING_WRITERS) @@ -2118,6 +2155,8 @@ class SQLConf extends Serializable with Logging { def setCommandRejectsSparkCoreConfs: Boolean = getConf(SQLConf.SET_COMMAND_REJECTS_SPARK_CORE_CONFS) + def defaultV2Catalog: Option[String] = getConf(DEFAULT_V2_CATALOG) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala similarity index 92% rename from sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala index 3f941cc6e1072..a1ab55a7185ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources -import org.apache.spark.annotation.Stable +import org.apache.spark.annotation.{Evolving, Stable} //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines all the filters that we can push down to the data sources. @@ -218,3 +218,27 @@ case class StringEndsWith(attribute: String, value: String) extends Filter { case class StringContains(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) } + +/** + * A filter that always evaluates to `true`. + */ +@Evolving +case class AlwaysTrue() extends Filter { + override def references: Array[String] = Array.empty +} + +@Evolving +object AlwaysTrue extends AlwaysTrue { +} + +/** + * A filter that always evaluates to `false`. + */ +@Evolving +case class AlwaysFalse() extends Filter { + override def references: Array[String] = Array.empty +} + +@Evolving +object AlwaysFalse extends AlwaysFalse { +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index d563276a5711d..c472bd8ee84b9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -307,6 +307,29 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru nameToIndex.get(name) } + /** + * Returns a field in this struct and its child structs. + * + * This does not support finding fields nested in maps or arrays. + */ + private[sql] def findNestedField(fieldNames: Seq[String]): Option[StructField] = { + fieldNames.headOption.flatMap(nameToField.get) match { + case Some(field) => + if (fieldNames.tail.isEmpty) { + Some(field) + } else { + field.dataType match { + case struct: StructType => + struct.findNestedField(fieldNames.tail) + case _ => + None + } + } + case _ => + None + } + } + protected[sql] def toAttributes: Seq[AttributeReference] = map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala index 7de6256aef084..62546a322d3c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.execution.arrow +package org.apache.spark.sql.util import scala.collection.JavaConverters._ diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalog/v2/CatalogLoadingSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalog/v2/CatalogLoadingSuite.java new file mode 100644 index 0000000000000..326b12f3618d3 --- /dev/null +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalog/v2/CatalogLoadingSuite.java @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2; + +import org.apache.spark.SparkException; +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.junit.Assert; +import org.junit.Test; + +import java.util.concurrent.Callable; + +public class CatalogLoadingSuite { + @Test + public void testLoad() throws SparkException { + SQLConf conf = new SQLConf(); + conf.setConfString("spark.sql.catalog.test-name", TestCatalogPlugin.class.getCanonicalName()); + + CatalogPlugin plugin = Catalogs.load("test-name", conf); + Assert.assertNotNull("Should instantiate a non-null plugin", plugin); + Assert.assertEquals("Plugin should have correct implementation", + TestCatalogPlugin.class, plugin.getClass()); + + TestCatalogPlugin testPlugin = (TestCatalogPlugin) plugin; + Assert.assertEquals("Options should contain no keys", 0, testPlugin.options.size()); + Assert.assertEquals("Catalog should have correct name", "test-name", testPlugin.name()); + } + + @Test + public void testInitializationOptions() throws SparkException { + SQLConf conf = new SQLConf(); + conf.setConfString("spark.sql.catalog.test-name", TestCatalogPlugin.class.getCanonicalName()); + conf.setConfString("spark.sql.catalog.test-name.name", "not-catalog-name"); + conf.setConfString("spark.sql.catalog.test-name.kEy", "valUE"); + + CatalogPlugin plugin = Catalogs.load("test-name", conf); + Assert.assertNotNull("Should instantiate a non-null plugin", plugin); + Assert.assertEquals("Plugin should have correct implementation", + TestCatalogPlugin.class, plugin.getClass()); + + TestCatalogPlugin testPlugin = (TestCatalogPlugin) plugin; + + Assert.assertEquals("Options should contain only two keys", 2, testPlugin.options.size()); + Assert.assertEquals("Options should contain correct value for name (not overwritten)", + "not-catalog-name", testPlugin.options.get("name")); + Assert.assertEquals("Options should contain correct value for key", + "valUE", testPlugin.options.get("key")); + } + + @Test + public void testLoadWithoutConfig() { + SQLConf conf = new SQLConf(); + + SparkException exc = intercept(CatalogNotFoundException.class, + () -> Catalogs.load("missing", conf)); + + Assert.assertTrue("Should complain that 
implementation is not configured", + exc.getMessage() + .contains("plugin class not found: spark.sql.catalog.missing is not defined")); + Assert.assertTrue("Should identify the catalog by name", + exc.getMessage().contains("missing")); + } + + @Test + public void testLoadMissingClass() { + SQLConf conf = new SQLConf(); + conf.setConfString("spark.sql.catalog.missing", "com.example.NoSuchCatalogPlugin"); + + SparkException exc = intercept(SparkException.class, () -> Catalogs.load("missing", conf)); + + Assert.assertTrue("Should complain that the class is not found", + exc.getMessage().contains("Cannot find catalog plugin class")); + Assert.assertTrue("Should identify the catalog by name", + exc.getMessage().contains("missing")); + Assert.assertTrue("Should identify the missing class", + exc.getMessage().contains("com.example.NoSuchCatalogPlugin")); + } + + @Test + public void testLoadNonCatalogPlugin() { + SQLConf conf = new SQLConf(); + String invalidClassName = InvalidCatalogPlugin.class.getCanonicalName(); + conf.setConfString("spark.sql.catalog.invalid", invalidClassName); + + SparkException exc = intercept(SparkException.class, () -> Catalogs.load("invalid", conf)); + + Assert.assertTrue("Should complain that class does not implement CatalogPlugin", + exc.getMessage().contains("does not implement CatalogPlugin")); + Assert.assertTrue("Should identify the catalog by name", + exc.getMessage().contains("invalid")); + Assert.assertTrue("Should identify the class", + exc.getMessage().contains(invalidClassName)); + } + + @Test + public void testLoadConstructorFailureCatalogPlugin() { + SQLConf conf = new SQLConf(); + String invalidClassName = ConstructorFailureCatalogPlugin.class.getCanonicalName(); + conf.setConfString("spark.sql.catalog.invalid", invalidClassName); + + RuntimeException exc = intercept(RuntimeException.class, () -> Catalogs.load("invalid", conf)); + + Assert.assertTrue("Should have expected error message", + exc.getMessage().contains("Expected failure")); + } + + @Test + public void testLoadAccessErrorCatalogPlugin() { + SQLConf conf = new SQLConf(); + String invalidClassName = AccessErrorCatalogPlugin.class.getCanonicalName(); + conf.setConfString("spark.sql.catalog.invalid", invalidClassName); + + SparkException exc = intercept(SparkException.class, () -> Catalogs.load("invalid", conf)); + + Assert.assertTrue("Should complain that no public constructor is provided", + exc.getMessage().contains("Failed to call public no-arg constructor for catalog")); + Assert.assertTrue("Should identify the catalog by name", + exc.getMessage().contains("invalid")); + Assert.assertTrue("Should identify the class", + exc.getMessage().contains(invalidClassName)); + } + + @SuppressWarnings("unchecked") + public static E intercept(Class expected, Callable callable) { + try { + callable.call(); + Assert.fail("No exception was thrown, expected: " + + expected.getName()); + } catch (Exception actual) { + try { + Assert.assertEquals(expected, actual.getClass()); + return (E) actual; + } catch (AssertionError e) { + e.addSuppressed(actual); + throw e; + } + } + // Compiler doesn't catch that Assert.fail will always throw an exception. 
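+    // Assert.fail throws an AssertionError, which the catch block above does not handle, so this
+    // line is effectively unreachable; it exists only to satisfy the compiler's return-path check.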
+ throw new UnsupportedOperationException("[BUG] Should not reach this statement"); + } +} + +class TestCatalogPlugin implements CatalogPlugin { + String name = null; + CaseInsensitiveStringMap options = null; + + TestCatalogPlugin() { + } + + @Override + public void initialize(String name, CaseInsensitiveStringMap options) { + this.name = name; + this.options = options; + } + + @Override + public String name() { + return name; + } +} + +class ConstructorFailureCatalogPlugin implements CatalogPlugin { // fails in its constructor + ConstructorFailureCatalogPlugin() { + throw new RuntimeException("Expected failure."); + } + + @Override + public void initialize(String name, CaseInsensitiveStringMap options) { + } + + @Override + public String name() { + return null; + } +} + +class AccessErrorCatalogPlugin implements CatalogPlugin { // no public constructor + private AccessErrorCatalogPlugin() { + } + + @Override + public void initialize(String name, CaseInsensitiveStringMap options) { + } + + @Override + public String name() { + return null; + } +} + +class InvalidCatalogPlugin { // doesn't implement CatalogPlugin + public void initialize(CaseInsensitiveStringMap options) { + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/TableCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/TableCatalogSuite.scala new file mode 100644 index 0000000000000..9c1b9a3e53de2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/TableCatalogSuite.scala @@ -0,0 +1,657 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalog.v2 + +import java.util +import java.util.Collections + +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType, StructField, StructType, TimestampType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class TableCatalogSuite extends SparkFunSuite { + import CatalogV2Implicits._ + + private val emptyProps: util.Map[String, String] = Collections.emptyMap[String, String] + private val schema: StructType = new StructType() + .add("id", IntegerType) + .add("data", StringType) + + private def newCatalog(): TableCatalog = { + val newCatalog = new TestTableCatalog + newCatalog.initialize("test", CaseInsensitiveStringMap.empty()) + newCatalog + } + + private val testIdent = Identifier.of(Array("`", "."), "test_table") + + test("Catalogs can load the catalog") { + val catalog = newCatalog() + + val conf = new SQLConf + conf.setConfString("spark.sql.catalog.test", catalog.getClass.getName) + + val loaded = Catalogs.load("test", conf) + assert(loaded.getClass == catalog.getClass) + } + + test("listTables") { + val catalog = newCatalog() + val ident1 = Identifier.of(Array("ns"), "test_table_1") + val ident2 = Identifier.of(Array("ns"), "test_table_2") + val ident3 = Identifier.of(Array("ns2"), "test_table_1") + + assert(catalog.listTables(Array("ns")).isEmpty) + + catalog.createTable(ident1, schema, Array.empty, emptyProps) + + assert(catalog.listTables(Array("ns")).toSet == Set(ident1)) + assert(catalog.listTables(Array("ns2")).isEmpty) + + catalog.createTable(ident3, schema, Array.empty, emptyProps) + catalog.createTable(ident2, schema, Array.empty, emptyProps) + + assert(catalog.listTables(Array("ns")).toSet == Set(ident1, ident2)) + assert(catalog.listTables(Array("ns2")).toSet == Set(ident3)) + + catalog.dropTable(ident1) + + assert(catalog.listTables(Array("ns")).toSet == Set(ident2)) + + catalog.dropTable(ident2) + + assert(catalog.listTables(Array("ns")).isEmpty) + assert(catalog.listTables(Array("ns2")).toSet == Set(ident3)) + } + + test("createTable") { + val catalog = newCatalog() + + assert(!catalog.tableExists(testIdent)) + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + val parsed = CatalystSqlParser.parseMultipartIdentifier(table.name) + assert(parsed == Seq("`", ".", "test_table")) + assert(table.schema == schema) + assert(table.properties.asScala == Map()) + + assert(catalog.tableExists(testIdent)) + } + + test("createTable: with properties") { + val catalog = newCatalog() + + val properties = new util.HashMap[String, String]() + properties.put("property", "value") + + assert(!catalog.tableExists(testIdent)) + + val table = catalog.createTable(testIdent, schema, Array.empty, properties) + + val parsed = CatalystSqlParser.parseMultipartIdentifier(table.name) + assert(parsed == Seq("`", ".", "test_table")) + assert(table.schema == schema) + assert(table.properties == properties) + + assert(catalog.tableExists(testIdent)) + } + + test("createTable: table already exists") { + val catalog = newCatalog() + + assert(!catalog.tableExists(testIdent)) + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + val exc = intercept[TableAlreadyExistsException] { + 
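+      // a second createTable with the same identifier should be rejected by the catalog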
catalog.createTable(testIdent, schema, Array.empty, emptyProps) + } + + assert(exc.message.contains(table.name())) + assert(exc.message.contains("already exists")) + + assert(catalog.tableExists(testIdent)) + } + + test("tableExists") { + val catalog = newCatalog() + + assert(!catalog.tableExists(testIdent)) + + catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(catalog.tableExists(testIdent)) + + catalog.dropTable(testIdent) + + assert(!catalog.tableExists(testIdent)) + } + + test("loadTable") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + val loaded = catalog.loadTable(testIdent) + + assert(table.name == loaded.name) + assert(table.schema == loaded.schema) + assert(table.properties == loaded.properties) + } + + test("loadTable: table does not exist") { + val catalog = newCatalog() + + val exc = intercept[NoSuchTableException] { + catalog.loadTable(testIdent) + } + + assert(exc.message.contains(testIdent.quoted)) + assert(exc.message.contains("not found")) + } + + test("invalidateTable") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + catalog.invalidateTable(testIdent) + + val loaded = catalog.loadTable(testIdent) + + assert(table.name == loaded.name) + assert(table.schema == loaded.schema) + assert(table.properties == loaded.properties) + } + + test("invalidateTable: table does not exist") { + val catalog = newCatalog() + + assert(catalog.tableExists(testIdent) === false) + + catalog.invalidateTable(testIdent) + } + + test("alterTable: add property") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.properties.asScala == Map()) + + val updated = catalog.alterTable(testIdent, TableChange.setProperty("prop-1", "1")) + assert(updated.properties.asScala == Map("prop-1" -> "1")) + + val loaded = catalog.loadTable(testIdent) + assert(loaded.properties.asScala == Map("prop-1" -> "1")) + + assert(table.properties.asScala == Map()) + } + + test("alterTable: add property to existing") { + val catalog = newCatalog() + + val properties = new util.HashMap[String, String]() + properties.put("prop-1", "1") + + val table = catalog.createTable(testIdent, schema, Array.empty, properties) + + assert(table.properties.asScala == Map("prop-1" -> "1")) + + val updated = catalog.alterTable(testIdent, TableChange.setProperty("prop-2", "2")) + assert(updated.properties.asScala == Map("prop-1" -> "1", "prop-2" -> "2")) + + val loaded = catalog.loadTable(testIdent) + assert(loaded.properties.asScala == Map("prop-1" -> "1", "prop-2" -> "2")) + + assert(table.properties.asScala == Map("prop-1" -> "1")) + } + + test("alterTable: remove existing property") { + val catalog = newCatalog() + + val properties = new util.HashMap[String, String]() + properties.put("prop-1", "1") + + val table = catalog.createTable(testIdent, schema, Array.empty, properties) + + assert(table.properties.asScala == Map("prop-1" -> "1")) + + val updated = catalog.alterTable(testIdent, TableChange.removeProperty("prop-1")) + assert(updated.properties.asScala == Map()) + + val loaded = catalog.loadTable(testIdent) + assert(loaded.properties.asScala == Map()) + + assert(table.properties.asScala == Map("prop-1" -> "1")) + } + + test("alterTable: remove missing property") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.properties.asScala == 
Map()) + + val updated = catalog.alterTable(testIdent, TableChange.removeProperty("prop-1")) + assert(updated.properties.asScala == Map()) + + val loaded = catalog.loadTable(testIdent) + assert(loaded.properties.asScala == Map()) + + assert(table.properties.asScala == Map()) + } + + test("alterTable: add top-level column") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val updated = catalog.alterTable(testIdent, TableChange.addColumn(Array("ts"), TimestampType)) + + assert(updated.schema == schema.add("ts", TimestampType)) + } + + test("alterTable: add required column") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val updated = catalog.alterTable(testIdent, + TableChange.addColumn(Array("ts"), TimestampType, false)) + + assert(updated.schema == schema.add("ts", TimestampType, nullable = false)) + } + + test("alterTable: add column with comment") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val updated = catalog.alterTable(testIdent, + TableChange.addColumn(Array("ts"), TimestampType, false, "comment text")) + + val field = StructField("ts", TimestampType, nullable = false).withComment("comment text") + assert(updated.schema == schema.add(field)) + } + + test("alterTable: add nested column") { + val catalog = newCatalog() + + val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType) + val tableSchema = schema.add("point", pointStruct) + + val table = catalog.createTable(testIdent, tableSchema, Array.empty, emptyProps) + + assert(table.schema == tableSchema) + + val updated = catalog.alterTable(testIdent, + TableChange.addColumn(Array("point", "z"), DoubleType)) + + val expectedSchema = schema.add("point", pointStruct.add("z", DoubleType)) + + assert(updated.schema == expectedSchema) + } + + test("alterTable: add column to primitive field fails") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val exc = intercept[IllegalArgumentException] { + catalog.alterTable(testIdent, TableChange.addColumn(Array("data", "ts"), TimestampType)) + } + + assert(exc.getMessage.contains("Not a struct")) + assert(exc.getMessage.contains("data")) + + // the table has not changed + assert(catalog.loadTable(testIdent).schema == schema) + } + + test("alterTable: add field to missing column fails") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val exc = intercept[IllegalArgumentException] { + catalog.alterTable(testIdent, + TableChange.addColumn(Array("missing_col", "new_field"), StringType)) + } + + assert(exc.getMessage.contains("missing_col")) + assert(exc.getMessage.contains("Cannot find")) + } + + test("alterTable: update column data type") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val updated = catalog.alterTable(testIdent, TableChange.updateColumnType(Array("id"), LongType)) + + val expectedSchema = new StructType().add("id", LongType).add("data", StringType) + assert(updated.schema == expectedSchema) + } + + test("alterTable: update column data type and nullability") { + val 
catalog = newCatalog() + + val originalSchema = new StructType() + .add("id", IntegerType, nullable = false) + .add("data", StringType) + val table = catalog.createTable(testIdent, originalSchema, Array.empty, emptyProps) + + assert(table.schema == originalSchema) + + val updated = catalog.alterTable(testIdent, + TableChange.updateColumnType(Array("id"), LongType, true)) + + val expectedSchema = new StructType().add("id", LongType).add("data", StringType) + assert(updated.schema == expectedSchema) + } + + test("alterTable: update optional column to required fails") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val exc = intercept[IllegalArgumentException] { + catalog.alterTable(testIdent, TableChange.updateColumnType(Array("id"), LongType, false)) + } + + assert(exc.getMessage.contains("Cannot change optional column to required")) + assert(exc.getMessage.contains("id")) + } + + test("alterTable: update missing column fails") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val exc = intercept[IllegalArgumentException] { + catalog.alterTable(testIdent, + TableChange.updateColumnType(Array("missing_col"), LongType)) + } + + assert(exc.getMessage.contains("missing_col")) + assert(exc.getMessage.contains("Cannot find")) + } + + test("alterTable: add comment") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val updated = catalog.alterTable(testIdent, + TableChange.updateColumnComment(Array("id"), "comment text")) + + val expectedSchema = new StructType() + .add("id", IntegerType, nullable = true, "comment text") + .add("data", StringType) + assert(updated.schema == expectedSchema) + } + + test("alterTable: replace comment") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + catalog.alterTable(testIdent, TableChange.updateColumnComment(Array("id"), "comment text")) + + val expectedSchema = new StructType() + .add("id", IntegerType, nullable = true, "replacement comment") + .add("data", StringType) + + val updated = catalog.alterTable(testIdent, + TableChange.updateColumnComment(Array("id"), "replacement comment")) + + assert(updated.schema == expectedSchema) + } + + test("alterTable: add comment to missing column fails") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val exc = intercept[IllegalArgumentException] { + catalog.alterTable(testIdent, + TableChange.updateColumnComment(Array("missing_col"), "comment")) + } + + assert(exc.getMessage.contains("missing_col")) + assert(exc.getMessage.contains("Cannot find")) + } + + test("alterTable: rename top-level column") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val updated = catalog.alterTable(testIdent, TableChange.renameColumn(Array("id"), "some_id")) + + val expectedSchema = new StructType().add("some_id", IntegerType).add("data", StringType) + + assert(updated.schema == expectedSchema) + } + + test("alterTable: rename nested column") { + val catalog = newCatalog() + + val pointStruct = new StructType().add("x", DoubleType).add("y", 
DoubleType) + val tableSchema = schema.add("point", pointStruct) + + val table = catalog.createTable(testIdent, tableSchema, Array.empty, emptyProps) + + assert(table.schema == tableSchema) + + val updated = catalog.alterTable(testIdent, + TableChange.renameColumn(Array("point", "x"), "first")) + + val newPointStruct = new StructType().add("first", DoubleType).add("y", DoubleType) + val expectedSchema = schema.add("point", newPointStruct) + + assert(updated.schema == expectedSchema) + } + + test("alterTable: rename struct column") { + val catalog = newCatalog() + + val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType) + val tableSchema = schema.add("point", pointStruct) + + val table = catalog.createTable(testIdent, tableSchema, Array.empty, emptyProps) + + assert(table.schema == tableSchema) + + val updated = catalog.alterTable(testIdent, + TableChange.renameColumn(Array("point"), "p")) + + val newPointStruct = new StructType().add("x", DoubleType).add("y", DoubleType) + val expectedSchema = schema.add("p", newPointStruct) + + assert(updated.schema == expectedSchema) + } + + test("alterTable: rename missing column fails") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val exc = intercept[IllegalArgumentException] { + catalog.alterTable(testIdent, + TableChange.renameColumn(Array("missing_col"), "new_name")) + } + + assert(exc.getMessage.contains("missing_col")) + assert(exc.getMessage.contains("Cannot find")) + } + + test("alterTable: multiple changes") { + val catalog = newCatalog() + + val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType) + val tableSchema = schema.add("point", pointStruct) + + val table = catalog.createTable(testIdent, tableSchema, Array.empty, emptyProps) + + assert(table.schema == tableSchema) + + val updated = catalog.alterTable(testIdent, + TableChange.renameColumn(Array("point", "x"), "first"), + TableChange.renameColumn(Array("point", "y"), "second")) + + val newPointStruct = new StructType().add("first", DoubleType).add("second", DoubleType) + val expectedSchema = schema.add("point", newPointStruct) + + assert(updated.schema == expectedSchema) + } + + test("alterTable: delete top-level column") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val updated = catalog.alterTable(testIdent, + TableChange.deleteColumn(Array("id"))) + + val expectedSchema = new StructType().add("data", StringType) + assert(updated.schema == expectedSchema) + } + + test("alterTable: delete nested column") { + val catalog = newCatalog() + + val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType) + val tableSchema = schema.add("point", pointStruct) + + val table = catalog.createTable(testIdent, tableSchema, Array.empty, emptyProps) + + assert(table.schema == tableSchema) + + val updated = catalog.alterTable(testIdent, + TableChange.deleteColumn(Array("point", "y"))) + + val newPointStruct = new StructType().add("x", DoubleType) + val expectedSchema = schema.add("point", newPointStruct) + + assert(updated.schema == expectedSchema) + } + + test("alterTable: delete missing column fails") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val exc = intercept[IllegalArgumentException] { + catalog.alterTable(testIdent, 
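+        // deleting a column that does not exist in the schema should be rejected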
TableChange.deleteColumn(Array("missing_col"))) + } + + assert(exc.getMessage.contains("missing_col")) + assert(exc.getMessage.contains("Cannot find")) + } + + test("alterTable: delete missing nested column fails") { + val catalog = newCatalog() + + val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType) + val tableSchema = schema.add("point", pointStruct) + + val table = catalog.createTable(testIdent, tableSchema, Array.empty, emptyProps) + + assert(table.schema == tableSchema) + + val exc = intercept[IllegalArgumentException] { + catalog.alterTable(testIdent, TableChange.deleteColumn(Array("point", "z"))) + } + + assert(exc.getMessage.contains("z")) + assert(exc.getMessage.contains("Cannot find")) + } + + test("alterTable: table does not exist") { + val catalog = newCatalog() + + val exc = intercept[NoSuchTableException] { + catalog.alterTable(testIdent, TableChange.setProperty("prop", "val")) + } + + assert(exc.message.contains(testIdent.quoted)) + assert(exc.message.contains("not found")) + } + + test("dropTable") { + val catalog = newCatalog() + + assert(!catalog.tableExists(testIdent)) + + catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(catalog.tableExists(testIdent)) + + val wasDropped = catalog.dropTable(testIdent) + + assert(wasDropped) + assert(!catalog.tableExists(testIdent)) + } + + test("dropTable: table does not exist") { + val catalog = newCatalog() + + assert(!catalog.tableExists(testIdent)) + + val wasDropped = catalog.dropTable(testIdent) + + assert(!wasDropped) + assert(!catalog.tableExists(testIdent)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/TestTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/TestTableCatalog.scala new file mode 100644 index 0000000000000..78b4763484cc0 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/TestTableCatalog.scala @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalog.v2 + +import java.util +import java.util.Collections +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.catalog.v2.TableChange.{AddColumn, DeleteColumn, RemoveProperty, RenameColumn, SetProperty, UpdateColumnComment, UpdateColumnType} +import org.apache.spark.sql.catalog.v2.expressions.Transform +import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.sources.v2.{Table, TableCapability} +import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class TestTableCatalog extends TableCatalog { + import CatalogV2Implicits._ + + private val tables: util.Map[Identifier, Table] = new ConcurrentHashMap[Identifier, Table]() + private var _name: Option[String] = None + + override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = { + _name = Some(name) + } + + override def name: String = _name.get + + override def listTables(namespace: Array[String]): Array[Identifier] = { + tables.keySet.asScala.filter(_.namespace.sameElements(namespace)).toArray + } + + override def loadTable(ident: Identifier): Table = { + Option(tables.get(ident)) match { + case Some(table) => + table + case _ => + throw new NoSuchTableException(ident) + } + } + + override def createTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): Table = { + + if (tables.containsKey(ident)) { + throw new TableAlreadyExistsException(ident) + } + + if (partitions.nonEmpty) { + throw new UnsupportedOperationException( + s"Catalog $name: Partitioned tables are not supported") + } + + val table = InMemoryTable(ident.quoted, schema, properties) + + tables.put(ident, table) + + table + } + + override def alterTable(ident: Identifier, changes: TableChange*): Table = { + val table = loadTable(ident) + val properties = TestTableCatalog.applyPropertiesChanges(table.properties, changes) + val schema = TestTableCatalog.applySchemaChanges(table.schema, changes) + val newTable = InMemoryTable(table.name, schema, properties) + + tables.put(ident, newTable) + + newTable + } + + override def dropTable(ident: Identifier): Boolean = Option(tables.remove(ident)).isDefined +} + +object TestTableCatalog { + /** + * Apply properties changes to a map and return the result. + */ + def applyPropertiesChanges( + properties: util.Map[String, String], + changes: Seq[TableChange]): util.Map[String, String] = { + val newProperties = new util.HashMap[String, String](properties) + + changes.foreach { + case set: SetProperty => + newProperties.put(set.property, set.value) + + case unset: RemoveProperty => + newProperties.remove(unset.property) + + case _ => + // ignore non-property changes + } + + Collections.unmodifiableMap(newProperties) + } + + /** + * Apply schema changes to a schema and return the result. 
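+   *
+   * For example, applying `TableChange.addColumn(Array("ts"), TimestampType)` should return the
+   * input schema with a nullable "ts" column appended.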
+ */ + def applySchemaChanges(schema: StructType, changes: Seq[TableChange]): StructType = { + changes.foldLeft(schema) { (schema, change) => + change match { + case add: AddColumn => + add.fieldNames match { + case Array(name) => + val newField = StructField(name, add.dataType, nullable = add.isNullable) + Option(add.comment) match { + case Some(comment) => + schema.add(newField.withComment(comment)) + case _ => + schema.add(newField) + } + + case names => + replace(schema, names.init, parent => parent.dataType match { + case parentType: StructType => + val field = StructField(names.last, add.dataType, nullable = add.isNullable) + val newParentType = Option(add.comment) match { + case Some(comment) => + parentType.add(field.withComment(comment)) + case None => + parentType.add(field) + } + + Some(StructField(parent.name, newParentType, parent.nullable, parent.metadata)) + + case _ => + throw new IllegalArgumentException(s"Not a struct: ${names.init.last}") + }) + } + + case rename: RenameColumn => + replace(schema, rename.fieldNames, field => + Some(StructField(rename.newName, field.dataType, field.nullable, field.metadata))) + + case update: UpdateColumnType => + replace(schema, update.fieldNames, field => { + if (!update.isNullable && field.nullable) { + throw new IllegalArgumentException( + s"Cannot change optional column to required: $field.name") + } + Some(StructField(field.name, update.newDataType, update.isNullable, field.metadata)) + }) + + case update: UpdateColumnComment => + replace(schema, update.fieldNames, field => + Some(field.withComment(update.newComment))) + + case delete: DeleteColumn => + replace(schema, delete.fieldNames, _ => None) + + case _ => + // ignore non-schema changes + schema + } + } + } + + private def replace( + struct: StructType, + path: Seq[String], + update: StructField => Option[StructField]): StructType = { + + val pos = struct.getFieldIndex(path.head) + .getOrElse(throw new IllegalArgumentException(s"Cannot find field: ${path.head}")) + val field = struct.fields(pos) + val replacement: Option[StructField] = if (path.tail.isEmpty) { + update(field) + } else { + field.dataType match { + case nestedStruct: StructType => + val updatedType: StructType = replace(nestedStruct, path.tail, update) + Some(StructField(field.name, updatedType, field.nullable, field.metadata)) + case _ => + throw new IllegalArgumentException(s"Not a struct: ${path.head}") + } + } + + val newFields = struct.fields.zipWithIndex.flatMap { + case (_, index) if pos == index => + replacement + case (other, _) => + Some(other) + } + + new StructType(newFields) + } +} + +case class InMemoryTable( + name: String, + schema: StructType, + override val properties: util.Map[String, String]) extends Table { + override def partitioning: Array[Transform] = Array.empty + override def capabilities: util.Set[TableCapability] = InMemoryTable.CAPABILITIES +} + +object InMemoryTable { + val CAPABILITIES: util.Set[TableCapability] = Set.empty[TableCapability].asJava +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala new file mode 100644 index 0000000000000..1ce8852f71bc8 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor 
license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalog.v2.{Identifier, TableCatalog, TestTableCatalog} +import org.apache.spark.sql.catalog.v2.expressions.LogicalExpressions +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, LeafNode} +import org.apache.spark.sql.types.{DoubleType, LongType, StringType, StructType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class CreateTablePartitioningValidationSuite extends AnalysisTest { + import CreateTablePartitioningValidationSuite._ + + test("CreateTableAsSelect: fail missing top-level column") { + val plan = CreateTableAsSelect( + catalog, + Identifier.of(Array(), "table_name"), + LogicalExpressions.bucket(4, "does_not_exist") :: Nil, + TestRelation2, + Map.empty, + Map.empty, + ignoreIfExists = false) + + assert(!plan.resolved) + assertAnalysisError(plan, Seq( + "Invalid partitioning", + "does_not_exist is missing or is in a map or array")) + } + + test("CreateTableAsSelect: fail missing top-level column nested reference") { + val plan = CreateTableAsSelect( + catalog, + Identifier.of(Array(), "table_name"), + LogicalExpressions.bucket(4, "does_not_exist.z") :: Nil, + TestRelation2, + Map.empty, + Map.empty, + ignoreIfExists = false) + + assert(!plan.resolved) + assertAnalysisError(plan, Seq( + "Invalid partitioning", + "does_not_exist.z is missing or is in a map or array")) + } + + test("CreateTableAsSelect: fail missing nested column") { + val plan = CreateTableAsSelect( + catalog, + Identifier.of(Array(), "table_name"), + LogicalExpressions.bucket(4, "point.z") :: Nil, + TestRelation2, + Map.empty, + Map.empty, + ignoreIfExists = false) + + assert(!plan.resolved) + assertAnalysisError(plan, Seq( + "Invalid partitioning", + "point.z is missing or is in a map or array")) + } + + test("CreateTableAsSelect: fail with multiple errors") { + val plan = CreateTableAsSelect( + catalog, + Identifier.of(Array(), "table_name"), + LogicalExpressions.bucket(4, "does_not_exist", "point.z") :: Nil, + TestRelation2, + Map.empty, + Map.empty, + ignoreIfExists = false) + + assert(!plan.resolved) + assertAnalysisError(plan, Seq( + "Invalid partitioning", + "point.z is missing or is in a map or array", + "does_not_exist is missing or is in a map or array")) + } + + test("CreateTableAsSelect: success with top-level column") { + val plan = CreateTableAsSelect( + catalog, + Identifier.of(Array(), "table_name"), + LogicalExpressions.bucket(4, "id") :: Nil, + TestRelation2, + Map.empty, + Map.empty, + ignoreIfExists = false) + + assertAnalysisSuccess(plan) + } + + test("CreateTableAsSelect: success using nested column") { + val plan = CreateTableAsSelect( + catalog, + Identifier.of(Array(), "table_name"), + 
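+      // "point.x" refers to a nested field of the "point" struct column in the test schema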
LogicalExpressions.bucket(4, "point.x") :: Nil, + TestRelation2, + Map.empty, + Map.empty, + ignoreIfExists = false) + + assertAnalysisSuccess(plan) + } + + test("CreateTableAsSelect: success using complex column") { + val plan = CreateTableAsSelect( + catalog, + Identifier.of(Array(), "table_name"), + LogicalExpressions.bucket(4, "point") :: Nil, + TestRelation2, + Map.empty, + Map.empty, + ignoreIfExists = false) + + assertAnalysisSuccess(plan) + } +} + +private object CreateTablePartitioningValidationSuite { + val catalog: TableCatalog = { + val cat = new TestTableCatalog() + cat.initialize("test", CaseInsensitiveStringMap.empty()) + cat + } + + val schema: StructType = new StructType() + .add("id", LongType) + .add("data", StringType) + .add("point", new StructType().add("x", DoubleType).add("y", DoubleType)) +} + +private case object TestRelation2 extends LeafNode with NamedRelation { + override def name: String = "source_relation" + override def output: Seq[AttributeReference] = + CreateTablePartitioningValidationSuite.schema.toAttributes +} + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala index 6c899b610ac5b..48b43fcccacef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala @@ -19,15 +19,98 @@ package org.apache.spark.sql.catalyst.analysis import java.util.Locale -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast, UpCast} -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LeafNode, LogicalPlan, Project} -import org.apache.spark.sql.types.{DoubleType, FloatType, StructField, StructType} +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast, Expression, LessThanOrEqual, Literal} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types._ + +class V2AppendDataAnalysisSuite extends DataSourceV2AnalysisSuite { + override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + AppendData.byName(table, query) + } + + override def byPosition(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + AppendData.byPosition(table, query) + } +} + +class V2OverwritePartitionsDynamicAnalysisSuite extends DataSourceV2AnalysisSuite { + override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + OverwritePartitionsDynamic.byName(table, query) + } + + override def byPosition(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + OverwritePartitionsDynamic.byPosition(table, query) + } +} + +class V2OverwriteByExpressionAnalysisSuite extends DataSourceV2AnalysisSuite { + override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + OverwriteByExpression.byName(table, query, Literal(true)) + } + + override def byPosition(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + OverwriteByExpression.byPosition(table, query, Literal(true)) + } + + test("delete expression is resolved using table fields") { + val table = TestRelation(StructType(Seq( + StructField("x", DoubleType, nullable = false), + StructField("y", DoubleType))).toAttributes) + + val query = TestRelation(StructType(Seq( + StructField("a", DoubleType, nullable = false), + StructField("b", DoubleType))).toAttributes) + + val a = 
query.output.head + val b = query.output.last + val x = table.output.head + + val parsedPlan = OverwriteByExpression.byPosition(table, query, + LessThanOrEqual(UnresolvedAttribute(Seq("x")), Literal(15.0d))) + + val expectedPlan = OverwriteByExpression.byPosition(table, + Project(Seq( + Alias(Cast(a, DoubleType, Some(conf.sessionLocalTimeZone)), "x")(), + Alias(Cast(b, DoubleType, Some(conf.sessionLocalTimeZone)), "y")()), + query), + LessThanOrEqual( + AttributeReference("x", DoubleType, nullable = false)(x.exprId), + Literal(15.0d))) + + assertNotResolved(parsedPlan) + checkAnalysis(parsedPlan, expectedPlan) + assertResolved(expectedPlan) + } + + test("delete expression is not resolved using query fields") { + val xRequiredTable = TestRelation(StructType(Seq( + StructField("x", DoubleType, nullable = false), + StructField("y", DoubleType))).toAttributes) + + val query = TestRelation(StructType(Seq( + StructField("a", DoubleType, nullable = false), + StructField("b", DoubleType))).toAttributes) + + // the write is resolved (checked above). this test plan is not because of the expression. + val parsedPlan = OverwriteByExpression.byPosition(xRequiredTable, query, + LessThanOrEqual(UnresolvedAttribute(Seq("a")), Literal(15.0d))) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq("cannot resolve", "`a`", "given input columns", "x, y")) + } +} case class TestRelation(output: Seq[AttributeReference]) extends LeafNode with NamedRelation { override def name: String = "table-name" } -class DataSourceV2AnalysisSuite extends AnalysisTest { +case class TestRelationAcceptAnySchema(output: Seq[AttributeReference]) + extends LeafNode with NamedRelation { + override def name: String = "test-name" + override def skipSchemaResolution: Boolean = true +} + +abstract class DataSourceV2AnalysisSuite extends AnalysisTest { val table = TestRelation(StructType(Seq( StructField("x", FloatType), StructField("y", FloatType))).toAttributes) @@ -40,21 +123,25 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { StructField("x", DoubleType), StructField("y", DoubleType))).toAttributes) - test("Append.byName: basic behavior") { + def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan + + def byPosition(table: NamedRelation, query: LogicalPlan): LogicalPlan + + test("byName: basic behavior") { val query = TestRelation(table.schema.toAttributes) - val parsedPlan = AppendData.byName(table, query) + val parsedPlan = byName(table, query) checkAnalysis(parsedPlan, parsedPlan) assertResolved(parsedPlan) } - test("Append.byName: does not match by position") { + test("byName: does not match by position") { val query = TestRelation(StructType(Seq( StructField("a", FloatType), StructField("b", FloatType))).toAttributes) - val parsedPlan = AppendData.byName(table, query) + val parsedPlan = byName(table, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -62,12 +149,12 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Cannot find data for output column", "'x'", "'y'")) } - test("Append.byName: case sensitive column resolution") { + test("byName: case sensitive column resolution") { val query = TestRelation(StructType(Seq( StructField("X", FloatType), // doesn't match case! 
StructField("y", FloatType))).toAttributes) - val parsedPlan = AppendData.byName(table, query) + val parsedPlan = byName(table, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -76,7 +163,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { caseSensitive = true) } - test("Append.byName: case insensitive column resolution") { + test("byName: case insensitive column resolution") { val query = TestRelation(StructType(Seq( StructField("X", FloatType), // doesn't match case! StructField("y", FloatType))).toAttributes) @@ -84,8 +171,8 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { val X = query.output.head val y = query.output.last - val parsedPlan = AppendData.byName(table, query) - val expectedPlan = AppendData.byName(table, + val parsedPlan = byName(table, query) + val expectedPlan = byName(table, Project(Seq( Alias(Cast(toLower(X), FloatType, Some(conf.sessionLocalTimeZone)), "x")(), Alias(Cast(y, FloatType, Some(conf.sessionLocalTimeZone)), "y")()), @@ -96,7 +183,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { assertResolved(expectedPlan) } - test("Append.byName: data columns are reordered by name") { + test("byName: data columns are reordered by name") { // out of order val query = TestRelation(StructType(Seq( StructField("y", FloatType), @@ -105,8 +192,8 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { val y = query.output.head val x = query.output.last - val parsedPlan = AppendData.byName(table, query) - val expectedPlan = AppendData.byName(table, + val parsedPlan = byName(table, query) + val expectedPlan = byName(table, Project(Seq( Alias(Cast(x, FloatType, Some(conf.sessionLocalTimeZone)), "x")(), Alias(Cast(y, FloatType, Some(conf.sessionLocalTimeZone)), "y")()), @@ -117,26 +204,26 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { assertResolved(expectedPlan) } - test("Append.byName: fail nullable data written to required columns") { - val parsedPlan = AppendData.byName(requiredTable, table) + test("byName: fail nullable data written to required columns") { + val parsedPlan = byName(requiredTable, table) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( "Cannot write incompatible data to table", "'table-name'", "Cannot write nullable values to non-null column", "'x'", "'y'")) } - test("Append.byName: allow required data written to nullable columns") { - val parsedPlan = AppendData.byName(table, requiredTable) + test("byName: allow required data written to nullable columns") { + val parsedPlan = byName(table, requiredTable) assertResolved(parsedPlan) checkAnalysis(parsedPlan, parsedPlan) } - test("Append.byName: missing required columns cause failure and are identified by name") { + test("byName: missing required columns cause failure and are identified by name") { // missing required field x val query = TestRelation(StructType(Seq( StructField("y", FloatType, nullable = false))).toAttributes) - val parsedPlan = AppendData.byName(requiredTable, query) + val parsedPlan = byName(requiredTable, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -144,12 +231,12 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Cannot find data for output column", "'x'")) } - test("Append.byName: missing optional columns cause failure and are identified by name") { + test("byName: missing optional columns cause failure and are identified by name") { // missing optional field x val query = TestRelation(StructType(Seq( StructField("y", FloatType))).toAttributes) - val parsedPlan = 
AppendData.byName(table, query) + val parsedPlan = byName(table, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -157,8 +244,8 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Cannot find data for output column", "'x'")) } - test("Append.byName: fail canWrite check") { - val parsedPlan = AppendData.byName(table, widerTable) + test("byName: fail canWrite check") { + val parsedPlan = byName(table, widerTable) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -166,12 +253,12 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Cannot safely cast", "'x'", "'y'", "DoubleType to FloatType")) } - test("Append.byName: insert safe cast") { + test("byName: insert safe cast") { val x = table.output.head val y = table.output.last - val parsedPlan = AppendData.byName(widerTable, table) - val expectedPlan = AppendData.byName(widerTable, + val parsedPlan = byName(widerTable, table) + val expectedPlan = byName(widerTable, Project(Seq( Alias(Cast(x, DoubleType, Some(conf.sessionLocalTimeZone)), "x")(), Alias(Cast(y, DoubleType, Some(conf.sessionLocalTimeZone)), "y")()), @@ -182,13 +269,13 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { assertResolved(expectedPlan) } - test("Append.byName: fail extra data fields") { + test("byName: fail extra data fields") { val query = TestRelation(StructType(Seq( StructField("x", FloatType), StructField("y", FloatType), StructField("z", FloatType))).toAttributes) - val parsedPlan = AppendData.byName(table, query) + val parsedPlan = byName(table, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -197,7 +284,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Data columns: 'x', 'y', 'z'")) } - test("Append.byName: multiple field errors are reported") { + test("byName: multiple field errors are reported") { val xRequiredTable = TestRelation(StructType(Seq( StructField("x", FloatType, nullable = false), StructField("y", DoubleType))).toAttributes) @@ -206,7 +293,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { StructField("x", DoubleType), StructField("b", FloatType))).toAttributes) - val parsedPlan = AppendData.byName(xRequiredTable, query) + val parsedPlan = byName(xRequiredTable, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -216,7 +303,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Cannot find data for output column", "'y'")) } - test("Append.byPosition: basic behavior") { + test("byPosition: basic behavior") { val query = TestRelation(StructType(Seq( StructField("a", FloatType), StructField("b", FloatType))).toAttributes) @@ -224,8 +311,8 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { val a = query.output.head val b = query.output.last - val parsedPlan = AppendData.byPosition(table, query) - val expectedPlan = AppendData.byPosition(table, + val parsedPlan = byPosition(table, query) + val expectedPlan = byPosition(table, Project(Seq( Alias(Cast(a, FloatType, Some(conf.sessionLocalTimeZone)), "x")(), Alias(Cast(b, FloatType, Some(conf.sessionLocalTimeZone)), "y")()), @@ -236,7 +323,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { assertResolved(expectedPlan) } - test("Append.byPosition: data columns are not reordered") { + test("byPosition: data columns are not reordered") { // out of order val query = TestRelation(StructType(Seq( StructField("y", FloatType), @@ -245,8 +332,8 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { val y = query.output.head val x = query.output.last - 
val parsedPlan = AppendData.byPosition(table, query) - val expectedPlan = AppendData.byPosition(table, + val parsedPlan = byPosition(table, query) + val expectedPlan = byPosition(table, Project(Seq( Alias(Cast(y, FloatType, Some(conf.sessionLocalTimeZone)), "x")(), Alias(Cast(x, FloatType, Some(conf.sessionLocalTimeZone)), "y")()), @@ -257,26 +344,26 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { assertResolved(expectedPlan) } - test("Append.byPosition: fail nullable data written to required columns") { - val parsedPlan = AppendData.byPosition(requiredTable, table) + test("byPosition: fail nullable data written to required columns") { + val parsedPlan = byPosition(requiredTable, table) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( "Cannot write incompatible data to table", "'table-name'", "Cannot write nullable values to non-null column", "'x'", "'y'")) } - test("Append.byPosition: allow required data written to nullable columns") { - val parsedPlan = AppendData.byPosition(table, requiredTable) + test("byPosition: allow required data written to nullable columns") { + val parsedPlan = byPosition(table, requiredTable) assertResolved(parsedPlan) checkAnalysis(parsedPlan, parsedPlan) } - test("Append.byPosition: missing required columns cause failure") { + test("byPosition: missing required columns cause failure") { // missing optional field x val query = TestRelation(StructType(Seq( StructField("y", FloatType, nullable = false))).toAttributes) - val parsedPlan = AppendData.byPosition(requiredTable, query) + val parsedPlan = byPosition(requiredTable, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -285,12 +372,12 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Data columns: 'y'")) } - test("Append.byPosition: missing optional columns cause failure") { + test("byPosition: missing optional columns cause failure") { // missing optional field x val query = TestRelation(StructType(Seq( StructField("y", FloatType))).toAttributes) - val parsedPlan = AppendData.byPosition(table, query) + val parsedPlan = byPosition(table, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -299,12 +386,12 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Data columns: 'y'")) } - test("Append.byPosition: fail canWrite check") { + test("byPosition: fail canWrite check") { val widerTable = TestRelation(StructType(Seq( StructField("a", DoubleType), StructField("b", DoubleType))).toAttributes) - val parsedPlan = AppendData.byPosition(table, widerTable) + val parsedPlan = byPosition(table, widerTable) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -312,7 +399,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Cannot safely cast", "'x'", "'y'", "DoubleType to FloatType")) } - test("Append.byPosition: insert safe cast") { + test("byPosition: insert safe cast") { val widerTable = TestRelation(StructType(Seq( StructField("a", DoubleType), StructField("b", DoubleType))).toAttributes) @@ -320,8 +407,8 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { val x = table.output.head val y = table.output.last - val parsedPlan = AppendData.byPosition(widerTable, table) - val expectedPlan = AppendData.byPosition(widerTable, + val parsedPlan = byPosition(widerTable, table) + val expectedPlan = byPosition(widerTable, Project(Seq( Alias(Cast(x, DoubleType, Some(conf.sessionLocalTimeZone)), "a")(), Alias(Cast(y, DoubleType, Some(conf.sessionLocalTimeZone)), "b")()), @@ -332,13 +419,13 @@ class 
DataSourceV2AnalysisSuite extends AnalysisTest { assertResolved(expectedPlan) } - test("Append.byPosition: fail extra data fields") { + test("byPosition: fail extra data fields") { val query = TestRelation(StructType(Seq( StructField("a", FloatType), StructField("b", FloatType), StructField("c", FloatType))).toAttributes) - val parsedPlan = AppendData.byName(table, query) + val parsedPlan = byName(table, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -347,7 +434,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Data columns: 'a', 'b', 'c'")) } - test("Append.byPosition: multiple field errors are reported") { + test("byPosition: multiple field errors are reported") { val xRequiredTable = TestRelation(StructType(Seq( StructField("x", FloatType, nullable = false), StructField("y", DoubleType))).toAttributes) @@ -356,7 +443,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { StructField("x", DoubleType), StructField("b", FloatType))).toAttributes) - val parsedPlan = AppendData.byPosition(xRequiredTable, query) + val parsedPlan = byPosition(xRequiredTable, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -365,6 +452,27 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Cannot safely cast", "'x'", "DoubleType to FloatType")) } + test("bypass output column resolution") { + val table = TestRelationAcceptAnySchema(StructType(Seq( + StructField("a", FloatType, nullable = false), + StructField("b", DoubleType))).toAttributes) + + val query = TestRelation(StructType(Seq( + StructField("s", StringType))).toAttributes) + + withClue("byName") { + val parsedPlan = byName(table, query) + assertResolved(parsedPlan) + checkAnalysis(parsedPlan, parsedPlan) + } + + withClue("byPosition") { + val parsedPlan = byPosition(table, query) + assertResolved(parsedPlan) + checkAnalysis(parsedPlan, parsedPlan) + } + } + def assertNotResolved(logicalPlan: LogicalPlan): Unit = { assert(!logicalPlan.resolved, s"Plan should not be resolved: $logicalPlan") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/LookupCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/LookupCatalogSuite.scala new file mode 100644 index 0000000000000..783751ff79865 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/LookupCatalogSuite.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.catalog.v2 + +import org.scalatest.Inside +import org.scalatest.Matchers._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, Identifier, LookupCatalog} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +private case class TestCatalogPlugin(override val name: String) extends CatalogPlugin { + + override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = Unit +} + +class LookupCatalogSuite extends SparkFunSuite with LookupCatalog with Inside { + import CatalystSqlParser._ + + private val catalogs = Seq("prod", "test").map(x => x -> new TestCatalogPlugin(x)).toMap + + override def lookupCatalog(name: String): CatalogPlugin = + catalogs.getOrElse(name, throw new CatalogNotFoundException(s"$name not found")) + + test("catalog object identifier") { + Seq( + ("tbl", None, Seq.empty, "tbl"), + ("db.tbl", None, Seq("db"), "tbl"), + ("prod.func", catalogs.get("prod"), Seq.empty, "func"), + ("ns1.ns2.tbl", None, Seq("ns1", "ns2"), "tbl"), + ("prod.db.tbl", catalogs.get("prod"), Seq("db"), "tbl"), + ("test.db.tbl", catalogs.get("test"), Seq("db"), "tbl"), + ("test.ns1.ns2.ns3.tbl", catalogs.get("test"), Seq("ns1", "ns2", "ns3"), "tbl"), + ("`db.tbl`", None, Seq.empty, "db.tbl"), + ("parquet.`file:/tmp/db.tbl`", None, Seq("parquet"), "file:/tmp/db.tbl"), + ("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", None, + Seq("org.apache.spark.sql.json"), "s3://buck/tmp/abc.json")).foreach { + case (sql, expectedCatalog, namespace, name) => + inside(parseMultipartIdentifier(sql)) { + case CatalogObjectIdentifier(catalog, ident) => + catalog shouldEqual expectedCatalog + ident shouldEqual Identifier.of(namespace.toArray, name) + } + } + } + + test("table identifier") { + Seq( + ("tbl", "tbl", None), + ("db.tbl", "tbl", Some("db")), + ("`db.tbl`", "db.tbl", None), + ("parquet.`file:/tmp/db.tbl`", "file:/tmp/db.tbl", Some("parquet")), + ("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", "s3://buck/tmp/abc.json", + Some("org.apache.spark.sql.json"))).foreach { + case (sql, table, db) => + inside (parseMultipartIdentifier(sql)) { + case AsTableIdentifier(ident) => + ident shouldEqual TableIdentifier(table, db) + } + } + Seq( + "prod.func", + "prod.db.tbl", + "ns1.ns2.tbl").foreach { sql => + parseMultipartIdentifier(sql) match { + case AsTableIdentifier(_) => + fail(s"$sql should not be resolved as TableIdentifier") + case _ => + } + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala new file mode 100644 index 0000000000000..35cd813ae65c5 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -0,0 +1,397 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.parser + +import org.apache.spark.sql.catalog.v2.expressions.{ApplyTransform, BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, YearsTransform} +import org.apache.spark.sql.catalyst.analysis.AnalysisTest +import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.sql.{CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement} +import org.apache.spark.sql.types.{IntegerType, StringType, StructType, TimestampType} +import org.apache.spark.unsafe.types.UTF8String + +class DDLParserSuite extends AnalysisTest { + import CatalystSqlParser._ + + private def intercept(sqlCommand: String, messages: String*): Unit = { + val e = intercept[ParseException](parsePlan(sqlCommand)) + messages.foreach { message => + assert(e.message.contains(message)) + } + } + + private def parseCompare(sql: String, expected: LogicalPlan): Unit = { + comparePlans(parsePlan(sql), expected, checkAnalysis = false) + } + + test("create table using - schema") { + val sql = "CREATE TABLE my_tab(a INT COMMENT 'test', b STRING) USING parquet" + + parsePlan(sql) match { + case create: CreateTableStatement => + assert(create.tableName == Seq("my_tab")) + assert(create.tableSchema == new StructType() + .add("a", IntegerType, nullable = true, "test") + .add("b", StringType)) + assert(create.partitioning.isEmpty) + assert(create.bucketSpec.isEmpty) + assert(create.properties.isEmpty) + assert(create.provider == "parquet") + assert(create.options.isEmpty) + assert(create.location.isEmpty) + assert(create.comment.isEmpty) + assert(!create.ifNotExists) + + case other => + fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + + intercept("CREATE TABLE my_tab(a: INT COMMENT 'test', b: STRING) USING parquet", + "no viable alternative at input") + } + + test("create table - with IF NOT EXISTS") { + val sql = "CREATE TABLE IF NOT EXISTS my_tab(a INT, b STRING) USING parquet" + + parsePlan(sql) match { + case create: CreateTableStatement => + assert(create.tableName == Seq("my_tab")) + assert(create.tableSchema == new StructType().add("a", IntegerType).add("b", StringType)) + assert(create.partitioning.isEmpty) + assert(create.bucketSpec.isEmpty) + assert(create.properties.isEmpty) + assert(create.provider == "parquet") + assert(create.options.isEmpty) + assert(create.location.isEmpty) + assert(create.comment.isEmpty) + assert(create.ifNotExists) + + case other => + fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + + test("create table - with partitioned by") { + val query = "CREATE TABLE my_tab(a INT comment 'test', b STRING) " + + "USING parquet PARTITIONED BY (a)" + + parsePlan(query) match { + case create: CreateTableStatement => + assert(create.tableName == Seq("my_tab")) + assert(create.tableSchema == 
new StructType() + .add("a", IntegerType, nullable = true, "test") + .add("b", StringType)) + assert(create.partitioning == Seq(IdentityTransform(FieldReference("a")))) + assert(create.bucketSpec.isEmpty) + assert(create.properties.isEmpty) + assert(create.provider == "parquet") + assert(create.options.isEmpty) + assert(create.location.isEmpty) + assert(create.comment.isEmpty) + assert(!create.ifNotExists) + + case other => + fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," + + s"got ${other.getClass.getName}: $query") + } + } + + test("create table - partitioned by transforms") { + val sql = + """ + |CREATE TABLE my_tab (a INT, b STRING, ts TIMESTAMP) USING parquet + |PARTITIONED BY ( + | a, + | bucket(16, b), + | years(ts), + | months(ts), + | days(ts), + | hours(ts), + | foo(a, "bar", 34)) + """.stripMargin + + parsePlan(sql) match { + case create: CreateTableStatement => + assert(create.tableName == Seq("my_tab")) + assert(create.tableSchema == new StructType() + .add("a", IntegerType) + .add("b", StringType) + .add("ts", TimestampType)) + assert(create.partitioning == Seq( + IdentityTransform(FieldReference("a")), + BucketTransform(LiteralValue(16, IntegerType), Seq(FieldReference("b"))), + YearsTransform(FieldReference("ts")), + MonthsTransform(FieldReference("ts")), + DaysTransform(FieldReference("ts")), + HoursTransform(FieldReference("ts")), + ApplyTransform("foo", Seq( + FieldReference("a"), + LiteralValue(UTF8String.fromString("bar"), StringType), + LiteralValue(34, IntegerType))))) + assert(create.bucketSpec.isEmpty) + assert(create.properties.isEmpty) + assert(create.provider == "parquet") + assert(create.options.isEmpty) + assert(create.location.isEmpty) + assert(create.comment.isEmpty) + assert(!create.ifNotExists) + + case other => + fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + + test("create table - with bucket") { + val query = "CREATE TABLE my_tab(a INT, b STRING) USING parquet " + + "CLUSTERED BY (a) SORTED BY (b) INTO 5 BUCKETS" + + parsePlan(query) match { + case create: CreateTableStatement => + assert(create.tableName == Seq("my_tab")) + assert(create.tableSchema == new StructType().add("a", IntegerType).add("b", StringType)) + assert(create.partitioning.isEmpty) + assert(create.bucketSpec.contains(BucketSpec(5, Seq("a"), Seq("b")))) + assert(create.properties.isEmpty) + assert(create.provider == "parquet") + assert(create.options.isEmpty) + assert(create.location.isEmpty) + assert(create.comment.isEmpty) + assert(!create.ifNotExists) + + case other => + fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," + + s"got ${other.getClass.getName}: $query") + } + } + + test("create table - with comment") { + val sql = "CREATE TABLE my_tab(a INT, b STRING) USING parquet COMMENT 'abc'" + + parsePlan(sql) match { + case create: CreateTableStatement => + assert(create.tableName == Seq("my_tab")) + assert(create.tableSchema == new StructType().add("a", IntegerType).add("b", StringType)) + assert(create.partitioning.isEmpty) + assert(create.bucketSpec.isEmpty) + assert(create.properties.isEmpty) + assert(create.provider == "parquet") + assert(create.options.isEmpty) + assert(create.location.isEmpty) + assert(create.comment.contains("abc")) + assert(!create.ifNotExists) + + case other => + fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," + + s"got 
${other.getClass.getName}: $sql") + } + } + + test("create table - with table properties") { + val sql = "CREATE TABLE my_tab(a INT, b STRING) USING parquet TBLPROPERTIES('test' = 'test')" + + parsePlan(sql) match { + case create: CreateTableStatement => + assert(create.tableName == Seq("my_tab")) + assert(create.tableSchema == new StructType().add("a", IntegerType).add("b", StringType)) + assert(create.partitioning.isEmpty) + assert(create.bucketSpec.isEmpty) + assert(create.properties == Map("test" -> "test")) + assert(create.provider == "parquet") + assert(create.options.isEmpty) + assert(create.location.isEmpty) + assert(create.comment.isEmpty) + assert(!create.ifNotExists) + + case other => + fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + + test("create table - with location") { + val sql = "CREATE TABLE my_tab(a INT, b STRING) USING parquet LOCATION '/tmp/file'" + + parsePlan(sql) match { + case create: CreateTableStatement => + assert(create.tableName == Seq("my_tab")) + assert(create.tableSchema == new StructType().add("a", IntegerType).add("b", StringType)) + assert(create.partitioning.isEmpty) + assert(create.bucketSpec.isEmpty) + assert(create.properties.isEmpty) + assert(create.provider == "parquet") + assert(create.options.isEmpty) + assert(create.location.contains("/tmp/file")) + assert(create.comment.isEmpty) + assert(!create.ifNotExists) + + case other => + fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + + test("create table - byte length literal table name") { + val sql = "CREATE TABLE 1m.2g(a INT) USING parquet" + + parsePlan(sql) match { + case create: CreateTableStatement => + assert(create.tableName == Seq("1m", "2g")) + assert(create.tableSchema == new StructType().add("a", IntegerType)) + assert(create.partitioning.isEmpty) + assert(create.bucketSpec.isEmpty) + assert(create.properties.isEmpty) + assert(create.provider == "parquet") + assert(create.options.isEmpty) + assert(create.location.isEmpty) + assert(create.comment.isEmpty) + assert(!create.ifNotExists) + + case other => + fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + + test("Duplicate clauses - create table") { + def createTableHeader(duplicateClause: String): String = { + s"CREATE TABLE my_tab(a INT, b STRING) USING parquet $duplicateClause $duplicateClause" + } + + intercept(createTableHeader("TBLPROPERTIES('test' = 'test2')"), + "Found duplicate clauses: TBLPROPERTIES") + intercept(createTableHeader("LOCATION '/tmp/file'"), + "Found duplicate clauses: LOCATION") + intercept(createTableHeader("COMMENT 'a table'"), + "Found duplicate clauses: COMMENT") + intercept(createTableHeader("CLUSTERED BY(b) INTO 256 BUCKETS"), + "Found duplicate clauses: CLUSTERED BY") + intercept(createTableHeader("PARTITIONED BY (b)"), + "Found duplicate clauses: PARTITIONED BY") + } + + test("support for other types in OPTIONS") { + val sql = + """ + |CREATE TABLE table_name USING json + |OPTIONS (a 1, b 0.1, c TRUE) + """.stripMargin + + parsePlan(sql) match { + case create: CreateTableStatement => + assert(create.tableName == Seq("table_name")) + assert(create.tableSchema == new StructType) + assert(create.partitioning.isEmpty) + assert(create.bucketSpec.isEmpty) + assert(create.properties.isEmpty) + assert(create.provider == "json") + 
assert(create.options == Map("a" -> "1", "b" -> "0.1", "c" -> "true")) + assert(create.location.isEmpty) + assert(create.comment.isEmpty) + assert(!create.ifNotExists) + + case other => + fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + + test("Test CTAS against native tables") { + val s1 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |COMMENT 'This is the staging page view table' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + val s2 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |LOCATION '/user/external/page_view' + |COMMENT 'This is the staging page view table' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + val s3 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |COMMENT 'This is the staging page view table' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + checkParsing(s1) + checkParsing(s2) + checkParsing(s3) + + def checkParsing(sql: String): Unit = { + parsePlan(sql) match { + case create: CreateTableAsSelectStatement => + assert(create.tableName == Seq("mydb", "page_view")) + assert(create.partitioning.isEmpty) + assert(create.bucketSpec.isEmpty) + assert(create.properties == Map("p1" -> "v1", "p2" -> "v2")) + assert(create.provider == "parquet") + assert(create.options.isEmpty) + assert(create.location.contains("/user/external/page_view")) + assert(create.comment.contains("This is the staging page view table")) + assert(create.ifNotExists) + + case other => + fail(s"Expected to parse ${classOf[CreateTableAsSelectStatement].getClass.getName} " + + s"from query, got ${other.getClass.getName}: $sql") + } + } + } + + test("drop table") { + parseCompare("DROP TABLE testcat.ns1.ns2.tbl", + DropTableStatement(Seq("testcat", "ns1", "ns2", "tbl"), ifExists = false, purge = false)) + parseCompare(s"DROP TABLE db.tab", + DropTableStatement(Seq("db", "tab"), ifExists = false, purge = false)) + parseCompare(s"DROP TABLE IF EXISTS db.tab", + DropTableStatement(Seq("db", "tab"), ifExists = true, purge = false)) + parseCompare(s"DROP TABLE tab", + DropTableStatement(Seq("tab"), ifExists = false, purge = false)) + parseCompare(s"DROP TABLE IF EXISTS tab", + DropTableStatement(Seq("tab"), ifExists = true, purge = false)) + parseCompare(s"DROP TABLE tab PURGE", + DropTableStatement(Seq("tab"), ifExists = false, purge = true)) + parseCompare(s"DROP TABLE IF EXISTS tab PURGE", + DropTableStatement(Seq("tab"), ifExists = true, purge = true)) + } + + test("drop view") { + parseCompare(s"DROP VIEW testcat.db.view", + DropViewStatement(Seq("testcat", "db", "view"), ifExists = false)) + parseCompare(s"DROP VIEW db.view", DropViewStatement(Seq("db", "view"), ifExists = false)) + parseCompare(s"DROP VIEW IF EXISTS db.view", + DropViewStatement(Seq("db", "view"), ifExists = true)) + parseCompare(s"DROP VIEW view", DropViewStatement(Seq("view"), ifExists = false)) + parseCompare(s"DROP VIEW IF EXISTS view", DropViewStatement(Seq("view"), ifExists = true)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala index ff0de0fb7c1f0..489b7f328f8fa 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -47,8 +47,8 @@ class TableIdentifierParserSuite extends SparkFunSuite { "cursor", "date", "decimal", "delete", "describe", "double", "drop", "exists", "external", "false", "fetch", "float", "for", "grant", "group", "grouping", "import", "in", "insert", "int", "into", "is", "pivot", "lateral", "like", "local", "none", "null", - "of", "order", "out", "outer", "partition", "percent", "procedure", "range", "reads", "revoke", - "rollup", "row", "rows", "set", "smallint", "table", "timestamp", "to", "trigger", + "of", "order", "out", "outer", "partition", "percent", "procedure", "query", "range", "reads", + "revoke", "rollup", "row", "rows", "set", "smallint", "table", "timestamp", "to", "trigger", "true", "truncate", "update", "user", "values", "with", "regexp", "rlike", "bigint", "binary", "boolean", "current_date", "current_timestamp", "date", "double", "float", "int", "smallint", "timestamp", "at", "position", "both", "leading", "trailing", "extract") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala index d801f62b62323..4439a7bb3ae87 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala @@ -15,9 +15,9 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.arrow +package org.apache.spark.sql.util -import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} +import org.apache.arrow.vector.types.pojo.ArrowType import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.util.DateTimeUtils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/CaseInsensitiveStringMapSuite.scala similarity index 53% rename from sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/util/CaseInsensitiveStringMapSuite.scala index cfa69a86de1a7..0accb471cada3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/CaseInsensitiveStringMapSuite.scala @@ -15,31 +15,38 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2 +package org.apache.spark.sql.util + +import java.util import scala.collection.JavaConverters._ import org.apache.spark.SparkFunSuite -/** - * A simple test suite to verify `DataSourceOptions`. 
- */ -class DataSourceOptionsSuite extends SparkFunSuite { +class CaseInsensitiveStringMapSuite extends SparkFunSuite { + + test("put and get") { + val options = CaseInsensitiveStringMap.empty() + intercept[UnsupportedOperationException] { + options.put("kEy", "valUE") + } + } - test("key is case-insensitive") { - val options = new DataSourceOptions(Map("foo" -> "bar").asJava) - assert(options.get("foo").get() == "bar") - assert(options.get("FoO").get() == "bar") - assert(!options.get("abc").isPresent) + test("clear") { + val options = new CaseInsensitiveStringMap(Map("kEy" -> "valUE").asJava) + intercept[UnsupportedOperationException] { + options.clear() + } } - test("value is case-sensitive") { - val options = new DataSourceOptions(Map("foo" -> "bAr").asJava) - assert(options.get("foo").get == "bAr") + test("key and value set") { + val options = new CaseInsensitiveStringMap(Map("kEy" -> "valUE").asJava) + assert(options.keySet().asScala == Set("key")) + assert(options.values().asScala.toSeq == Seq("valUE")) } test("getInt") { - val options = new DataSourceOptions(Map("numFOo" -> "1", "foo" -> "bar").asJava) + val options = new CaseInsensitiveStringMap(Map("numFOo" -> "1", "foo" -> "bar").asJava) assert(options.getInt("numFOO", 10) == 1) assert(options.getInt("numFOO2", 10) == 10) @@ -49,17 +56,20 @@ class DataSourceOptionsSuite extends SparkFunSuite { } test("getBoolean") { - val options = new DataSourceOptions( + val options = new CaseInsensitiveStringMap( Map("isFoo" -> "true", "isFOO2" -> "false", "foo" -> "bar").asJava) assert(options.getBoolean("isFoo", false)) assert(!options.getBoolean("isFoo2", true)) assert(options.getBoolean("isBar", true)) assert(!options.getBoolean("isBar", false)) - assert(!options.getBoolean("FOO", true)) + + intercept[IllegalArgumentException] { + options.getBoolean("FOO", true) + } } test("getLong") { - val options = new DataSourceOptions(Map("numFoo" -> "9223372036854775807", + val options = new CaseInsensitiveStringMap(Map("numFoo" -> "9223372036854775807", "foo" -> "bar").asJava) assert(options.getLong("numFOO", 0L) == 9223372036854775807L) assert(options.getLong("numFoo2", -1L) == -1L) @@ -70,7 +80,7 @@ class DataSourceOptionsSuite extends SparkFunSuite { } test("getDouble") { - val options = new DataSourceOptions(Map("numFoo" -> "922337.1", + val options = new CaseInsensitiveStringMap(Map("numFoo" -> "922337.1", "foo" -> "bar").asJava) assert(options.getDouble("numFOO", 0d) == 922337.1d) assert(options.getDouble("numFoo2", -1.02d) == -1.02d) @@ -80,28 +90,19 @@ class DataSourceOptionsSuite extends SparkFunSuite { } } - test("standard options") { - val options = new DataSourceOptions(Map( - DataSourceOptions.PATH_KEY -> "abc", - DataSourceOptions.TABLE_KEY -> "tbl").asJava) - - assert(options.paths().toSeq == Seq("abc")) - assert(options.tableName().get() == "tbl") - assert(!options.databaseName().isPresent) - } - - test("standard options with both singular path and multi-paths") { - val options = new DataSourceOptions(Map( - DataSourceOptions.PATH_KEY -> "abc", - DataSourceOptions.PATHS_KEY -> """["c", "d"]""").asJava) - - assert(options.paths().toSeq == Seq("abc", "c", "d")) - } - - test("standard options with only multi-paths") { - val options = new DataSourceOptions(Map( - DataSourceOptions.PATHS_KEY -> """["c", "d\"e"]""").asJava) + test("asCaseSensitiveMap") { + val originalMap = new util.HashMap[String, String] { + put("Foo", "Bar") + put("OFO", "ABR") + put("OoF", "bar") + } - assert(options.paths().toSeq == Seq("c", "d\"e")) + val options = 
new CaseInsensitiveStringMap(originalMap) + val caseSensitiveMap = options.asCaseSensitiveMap + assert(caseSensitiveMap.equals(originalMap)) + // The result of `asCaseSensitiveMap` is read-only. + intercept[UnsupportedOperationException] { + caseSensitiveMap.put("kEy", "valUE") + } } } diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 95e98c5444721..6f0db3632d7dd 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -112,10 +112,6 @@ com.fasterxml.jackson.core jackson-databind - - org.apache.arrow - arrow-vector - org.apache.xbean xbean-asm7-shaded diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java b/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/Offset.java similarity index 68% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java rename to sql/core/src/main/java/org/apache/spark/sql/execution/streaming/Offset.java index c44b8af2552f0..7c167dc012329 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/Offset.java @@ -18,12 +18,10 @@ package org.apache.spark.sql.execution.streaming; /** - * The shared interface between V1 streaming sources and V2 streaming readers. + * This class is an alias of {@link org.apache.spark.sql.sources.v2.reader.streaming.Offset}. It's + * internal and deprecated. New streaming data source implementations should use data source v2 API, + * which will be supported in the long term. * - * This is a temporary interface for compatibility during migration. It should not be implemented - * directly, and will be removed in future versions. + * This class will be removed in a future release. */ -public interface BaseStreamingSource { - /** Stop this source and free any resources it has allocated. 
*/ - void stop(); -} +public abstract class Offset extends org.apache.spark.sql.sources.v2.reader.streaming.Offset {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java index 4e4242fe8d9b9..fca7e36859126 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java @@ -26,7 +26,6 @@ import org.apache.spark.sql.vectorized.ColumnarBatch; import org.apache.spark.sql.vectorized.ColumnarMap; import org.apache.spark.sql.vectorized.ColumnarRow; -import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -39,17 +38,10 @@ */ public final class MutableColumnarRow extends InternalRow { public int rowId; - private final ColumnVector[] columns; - private final WritableColumnVector[] writableColumns; - - public MutableColumnarRow(ColumnVector[] columns) { - this.columns = columns; - this.writableColumns = null; - } + private final WritableColumnVector[] columns; public MutableColumnarRow(WritableColumnVector[] writableColumns) { this.columns = writableColumns; - this.writableColumns = writableColumns; } @Override @@ -228,54 +220,54 @@ public void update(int ordinal, Object value) { @Override public void setNullAt(int ordinal) { - writableColumns[ordinal].putNull(rowId); + columns[ordinal].putNull(rowId); } @Override public void setBoolean(int ordinal, boolean value) { - writableColumns[ordinal].putNotNull(rowId); - writableColumns[ordinal].putBoolean(rowId, value); + columns[ordinal].putNotNull(rowId); + columns[ordinal].putBoolean(rowId, value); } @Override public void setByte(int ordinal, byte value) { - writableColumns[ordinal].putNotNull(rowId); - writableColumns[ordinal].putByte(rowId, value); + columns[ordinal].putNotNull(rowId); + columns[ordinal].putByte(rowId, value); } @Override public void setShort(int ordinal, short value) { - writableColumns[ordinal].putNotNull(rowId); - writableColumns[ordinal].putShort(rowId, value); + columns[ordinal].putNotNull(rowId); + columns[ordinal].putShort(rowId, value); } @Override public void setInt(int ordinal, int value) { - writableColumns[ordinal].putNotNull(rowId); - writableColumns[ordinal].putInt(rowId, value); + columns[ordinal].putNotNull(rowId); + columns[ordinal].putInt(rowId, value); } @Override public void setLong(int ordinal, long value) { - writableColumns[ordinal].putNotNull(rowId); - writableColumns[ordinal].putLong(rowId, value); + columns[ordinal].putNotNull(rowId); + columns[ordinal].putLong(rowId, value); } @Override public void setFloat(int ordinal, float value) { - writableColumns[ordinal].putNotNull(rowId); - writableColumns[ordinal].putFloat(rowId, value); + columns[ordinal].putNotNull(rowId); + columns[ordinal].putFloat(rowId, value); } @Override public void setDouble(int ordinal, double value) { - writableColumns[ordinal].putNotNull(rowId); - writableColumns[ordinal].putDouble(rowId, value); + columns[ordinal].putNotNull(rowId); + columns[ordinal].putDouble(rowId, value); } @Override public void setDecimal(int ordinal, Decimal value, int precision) { - writableColumns[ordinal].putNotNull(rowId); - writableColumns[ordinal].putDecimal(rowId, value, precision); + columns[ordinal].putNotNull(rowId); + columns[ordinal].putDecimal(rowId, value, precision); } } diff 
--git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java deleted file mode 100644 index 00af0bf1b172c..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java +++ /dev/null @@ -1,210 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import java.io.IOException; -import java.util.HashMap; -import java.util.Locale; -import java.util.Map; -import java.util.Optional; -import java.util.stream.Stream; - -import com.fasterxml.jackson.databind.ObjectMapper; - -import org.apache.spark.annotation.Evolving; - -/** - * An immutable string-to-string map in which keys are case-insensitive. This is used to represent - * data source options. - * - * Each data source implementation can define its own options and teach its users how to set them. - * Spark doesn't have any restrictions about what options a data source should or should not have. - * Instead Spark defines some standard options that data sources can optionally adopt. It's possible - * that some options are very common and many data sources use them. However different data - * sources may define the common options(key and meaning) differently, which is quite confusing to - * end users. - * - * The standard options defined by Spark: - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - *
- *   Option key | Option value
- *   path | A path string of the data files/directories, like path1, /absolute/file2, path3/*. The path can either be relative or absolute, points to either file or directory, and can contain wildcards. This option is commonly used by file-based data sources.
- *   paths | A JSON array style paths string of the data files/directories, like ["path1", "/absolute/file2"]. The format of each path is same as the path option, plus it should follow JSON string literal format, e.g. quotes should be escaped, pa\"th means pa"th.
- *   table | A table name string representing the table name directly without any interpretation. For example, db.tbl means a table called db.tbl, not a table called tbl inside database db. `t*b.l` means a table called `t*b.l`, not t*b.l.
- *   database | A database name string representing the database name directly without any interpretation, which is very similar to the table name option.
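This deleted class is superseded in the patch by `org.apache.spark.sql.util.CaseInsensitiveStringMap`, whose behaviour is pinned down by the migrated test suite earlier in this diff. A minimal sketch of the replacement lookup behaviour, assuming only the constructor and getters exercised there and that the class behaves as a read-only `java.util.Map[String, String]`:

```scala
import scala.collection.JavaConverters._

import org.apache.spark.sql.util.CaseInsensitiveStringMap

// Keys are matched case-insensitively; values keep their original case.
val options = new CaseInsensitiveStringMap(
  Map("path" -> "/tmp/data", "numFoo" -> "1", "isFoo" -> "true").asJava)

assert(options.get("PATH") == "/tmp/data")   // plain Map lookup, case-insensitive key
assert(options.getInt("NUMFOO", 10) == 1)    // typed getter with a default
assert(options.getBoolean("isfoo", false))
// Mutators such as put(...) and clear() throw UnsupportedOperationException,
// and asCaseSensitiveMap() returns a read-only view of the original keys.
```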
- */ -@Evolving -public class DataSourceOptions { - private final Map keyLowerCasedMap; - - private String toLowerCase(String key) { - return key.toLowerCase(Locale.ROOT); - } - - public static DataSourceOptions empty() { - return new DataSourceOptions(new HashMap<>()); - } - - public DataSourceOptions(Map originalMap) { - keyLowerCasedMap = new HashMap<>(originalMap.size()); - for (Map.Entry entry : originalMap.entrySet()) { - keyLowerCasedMap.put(toLowerCase(entry.getKey()), entry.getValue()); - } - } - - public Map asMap() { - return new HashMap<>(keyLowerCasedMap); - } - - /** - * Returns the option value to which the specified key is mapped, case-insensitively. - */ - public Optional get(String key) { - return Optional.ofNullable(keyLowerCasedMap.get(toLowerCase(key))); - } - - /** - * Returns the boolean value to which the specified key is mapped, - * or defaultValue if there is no mapping for the key. The key match is case-insensitive - */ - public boolean getBoolean(String key, boolean defaultValue) { - String lcaseKey = toLowerCase(key); - return keyLowerCasedMap.containsKey(lcaseKey) ? - Boolean.parseBoolean(keyLowerCasedMap.get(lcaseKey)) : defaultValue; - } - - /** - * Returns the integer value to which the specified key is mapped, - * or defaultValue if there is no mapping for the key. The key match is case-insensitive - */ - public int getInt(String key, int defaultValue) { - String lcaseKey = toLowerCase(key); - return keyLowerCasedMap.containsKey(lcaseKey) ? - Integer.parseInt(keyLowerCasedMap.get(lcaseKey)) : defaultValue; - } - - /** - * Returns the long value to which the specified key is mapped, - * or defaultValue if there is no mapping for the key. The key match is case-insensitive - */ - public long getLong(String key, long defaultValue) { - String lcaseKey = toLowerCase(key); - return keyLowerCasedMap.containsKey(lcaseKey) ? - Long.parseLong(keyLowerCasedMap.get(lcaseKey)) : defaultValue; - } - - /** - * Returns the double value to which the specified key is mapped, - * or defaultValue if there is no mapping for the key. The key match is case-insensitive - */ - public double getDouble(String key, double defaultValue) { - String lcaseKey = toLowerCase(key); - return keyLowerCasedMap.containsKey(lcaseKey) ? - Double.parseDouble(keyLowerCasedMap.get(lcaseKey)) : defaultValue; - } - - /** - * The option key for singular path. - */ - public static final String PATH_KEY = "path"; - - /** - * The option key for multiple paths. - */ - public static final String PATHS_KEY = "paths"; - - /** - * The option key for table name. - */ - public static final String TABLE_KEY = "table"; - - /** - * The option key for database name. - */ - public static final String DATABASE_KEY = "database"; - - /** - * The option key for whether to check existence of files for a table. - */ - public static final String CHECK_FILES_EXIST_KEY = "check_files_exist"; - - /** - * Returns all the paths specified by both the singular path option and the multiple - * paths option. 
- */ - public String[] paths() { - String[] singularPath = - get(PATH_KEY).map(s -> new String[]{s}).orElseGet(() -> new String[0]); - Optional pathsStr = get(PATHS_KEY); - if (pathsStr.isPresent()) { - ObjectMapper objectMapper = new ObjectMapper(); - try { - String[] paths = objectMapper.readValue(pathsStr.get(), String[].class); - return Stream.of(singularPath, paths).flatMap(Stream::of).toArray(String[]::new); - } catch (IOException e) { - return singularPath; - } - } else { - return singularPath; - } - } - - /** - * Returns the value of the table name option. - */ - public Optional tableName() { - return get(TABLE_KEY); - } - - /** - * Returns the value of the database name option. - */ - public Optional databaseName() { - return get(DATABASE_KEY); - } - - public Boolean checkFilesExist() { - Optional result = get(CHECK_FILES_EXIST_KEY); - return result.isPresent() && result.get().equals("true"); - } -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java deleted file mode 100644 index 8ac9c51750865..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.execution.streaming.BaseStreamingSink; -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport; -import org.apache.spark.sql.streaming.OutputMode; -import org.apache.spark.sql.types.StructType; - -/** - * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data writing ability for structured streaming. - * - * This interface is used to create {@link StreamingWriteSupport} instances when end users run - * {@code Dataset.writeStream.format(...).option(...).start()}. - */ -@Evolving -public interface StreamingWriteSupportProvider extends DataSourceV2, BaseStreamingSink { - - /** - * Creates a {@link StreamingWriteSupport} instance to save the data to this data source, which is - * called by Spark at the beginning of each streaming query. - * - * @param queryId A unique string for the writing query. It's possible that there are many - * writing queries running at the same time, and the returned - * {@link StreamingWriteSupport} can use this id to distinguish itself from others. - * @param schema the schema of the data to be written. - * @param mode the output mode which determines what successive epoch output means to this - * sink, please refer to {@link OutputMode} for more details. 
- * @param options the options for the returned data source writer, which is an immutable - * case-insensitive string-to-string map. - */ - StreamingWriteSupport createStreamingWriteSupport( - String queryId, - StructType schema, - OutputMode mode, - DataSourceOptions options); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java deleted file mode 100644 index 6c5a95d2a75b7..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.sources.v2.reader.Scan; -import org.apache.spark.sql.sources.v2.reader.ScanBuilder; - -/** - * An empty mix-in interface for {@link Table}, to indicate this table supports batch scan. - *

- * If a {@link Table} implements this interface, the - * {@link SupportsRead#newScanBuilder(DataSourceOptions)} must return a {@link ScanBuilder} that - * builds {@link Scan} with {@link Scan#toBatch()} implemented. - *

- */ -@Evolving -public interface SupportsBatchRead extends SupportsRead { } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchWrite.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchWrite.java deleted file mode 100644 index 08caadd5308e6..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchWrite.java +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.sources.v2.writer.WriteBuilder; - -/** - * An empty mix-in interface for {@link Table}, to indicate this table supports batch write. - *

- * If a {@link Table} implements this interface, the - * {@link SupportsWrite#newWriteBuilder(DataSourceOptions)} must return a {@link WriteBuilder} - * with {@link WriteBuilder#buildForBatch()} implemented. - *

- */ -@Evolving -public interface SupportsBatchWrite extends SupportsWrite {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java deleted file mode 100644 index 07546a54013ec..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.vectorized; - -import java.util.*; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.execution.vectorized.MutableColumnarRow; - -/** - * This class wraps multiple ColumnVectors as a row-wise table. It provides a row view of this - * batch so that Spark can access the data row by row. Instance of it is meant to be reused during - * the entire data loading process. - */ -@Evolving -public final class ColumnarBatch { - private int numRows; - private final ColumnVector[] columns; - - // Staging row returned from `getRow`. - private final MutableColumnarRow row; - - /** - * Called to close all the columns in this batch. It is not valid to access the data after - * calling this. This must be called at the end to clean up memory allocations. - */ - public void close() { - for (ColumnVector c: columns) { - c.close(); - } - } - - /** - * Returns an iterator over the rows in this batch. - */ - public Iterator rowIterator() { - final int maxRows = numRows; - final MutableColumnarRow row = new MutableColumnarRow(columns); - return new Iterator() { - int rowId = 0; - - @Override - public boolean hasNext() { - return rowId < maxRows; - } - - @Override - public InternalRow next() { - if (rowId >= maxRows) { - throw new NoSuchElementException(); - } - row.rowId = rowId++; - return row; - } - - @Override - public void remove() { - throw new UnsupportedOperationException(); - } - }; - } - - /** - * Sets the number of rows in this batch. - */ - public void setNumRows(int numRows) { - this.numRows = numRows; - } - - /** - * Returns the number of columns that make up this batch. - */ - public int numCols() { return columns.length; } - - /** - * Returns the number of rows for read, including filtered rows. - */ - public int numRows() { return numRows; } - - /** - * Returns the column at `ordinal`. - */ - public ColumnVector column(int ordinal) { return columns[ordinal]; } - - /** - * Returns the row in this batch at `rowId`. Returned row is reused across calls. 
- */ - public InternalRow getRow(int rowId) { - assert(rowId >= 0 && rowId < numRows); - row.rowId = rowId; - return row; - } - - public ColumnarBatch(ColumnVector[] columns) { - this.columns = columns; - this.row = new MutableColumnarRow(columns); - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index a380a06cb942b..0cf9957539e73 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -37,9 +37,11 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.csv._ import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, FileDataSourceV2, FileTable} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, FileDataSourceV2} import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.TableCapability._ import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.unsafe.types.UTF8String /** @@ -176,7 +178,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { */ def load(path: String): DataFrame = { // force invocation of `load(...varargs...)` - option(DataSourceOptions.PATH_KEY, path).load(Seq.empty: _*) + option("path", path).load(Seq.empty: _*) } /** @@ -193,7 +195,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } val useV1Sources = - sparkSession.sessionState.conf.userV1SourceReaderList.toLowerCase(Locale.ROOT).split(",") + sparkSession.sessionState.conf.useV1SourceReaderList.toLowerCase(Locale.ROOT).split(",") val lookupCls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) val cls = lookupCls.newInstance() match { case f: FileDataSourceV2 if useV1Sources.contains(f.shortName()) || @@ -205,21 +207,25 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { if (classOf[TableProvider].isAssignableFrom(cls)) { val provider = cls.getConstructor().newInstance().asInstanceOf[TableProvider] val sessionOptions = DataSourceV2Utils.extractSessionConfigs( - ds = provider, conf = sparkSession.sessionState.conf) - val pathsOption = { + source = provider, conf = sparkSession.sessionState.conf) + val pathsOption = if (paths.isEmpty) { + None + } else { val objectMapper = new ObjectMapper() - DataSourceOptions.PATHS_KEY -> objectMapper.writeValueAsString(paths.toArray) + Some("paths" -> objectMapper.writeValueAsString(paths.toArray)) } - val checkFilesExistsOption = DataSourceOptions.CHECK_FILES_EXIST_KEY -> "true" - val finalOptions = sessionOptions ++ extraOptions.toMap + pathsOption + checkFilesExistsOption - val dsOptions = new DataSourceOptions(finalOptions.asJava) + // TODO SPARK-27113: remove this option. 
+ val checkFilesExistsOpt = "check_files_exist" -> "true" + val finalOptions = sessionOptions ++ extraOptions.toMap ++ pathsOption + checkFilesExistsOpt + val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava) val table = userSpecifiedSchema match { case Some(schema) => provider.getTable(dsOptions, schema) case _ => provider.getTable(dsOptions) } + import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ table match { - case _: SupportsBatchRead => - Dataset.ofRows(sparkSession, DataSourceV2Relation.create(table, finalOptions)) + case _: SupportsRead if table.supports(BATCH_READ) => + Dataset.ofRows(sparkSession, DataSourceV2Relation.create(table, dsOptions)) case _ => loadV1Source(paths: _*) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 47fb548ecd43c..b87b3bd4f0761 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -25,15 +25,18 @@ import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation} import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoTable, LogicalPlan} +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoTable, LogicalPlan, OverwriteByExpression} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command.DDLUtils -import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation} -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, FileDataSourceV2, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, DataSourceUtils, LogicalRelation} +import org.apache.spark.sql.execution.datasources.v2._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.writer.SupportsSaveMode +import org.apache.spark.sql.sources.v2.TableCapability._ import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap /** * Interface used to write a [[Dataset]] to external storage systems (e.g. file systems, @@ -52,13 +55,16 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
<ul>
 * <li>`SaveMode.Overwrite`: overwrite the existing data.</li>
 * <li>`SaveMode.Append`: append the data.</li>
 * <li>`SaveMode.Ignore`: ignore the operation (i.e. no-op).</li>
- * <li>`SaveMode.ErrorIfExists`: default option, throw an exception at runtime.</li>
+ * <li>`SaveMode.ErrorIfExists`: throw an exception at runtime.</li>
 * </ul>
+ * <p>
    + * When writing to data source v1, the default option is `ErrorIfExists`. When writing to data + * source v2, the default option is `Append`. * * @since 1.4.0 */ def mode(saveMode: SaveMode): DataFrameWriter[T] = { - this.mode = saveMode + this.mode = Some(saveMode) this } @@ -74,15 +80,15 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * @since 1.4.0 */ def mode(saveMode: String): DataFrameWriter[T] = { - this.mode = saveMode.toLowerCase(Locale.ROOT) match { - case "overwrite" => SaveMode.Overwrite - case "append" => SaveMode.Append - case "ignore" => SaveMode.Ignore - case "error" | "errorifexists" | "default" => SaveMode.ErrorIfExists + saveMode.toLowerCase(Locale.ROOT) match { + case "overwrite" => mode(SaveMode.Overwrite) + case "append" => mode(SaveMode.Append) + case "ignore" => mode(SaveMode.Ignore) + case "error" | "errorifexists" => mode(SaveMode.ErrorIfExists) + case "default" => this case _ => throw new IllegalArgumentException(s"Unknown save mode: $saveMode. " + "Accepted save modes are 'overwrite', 'append', 'ignore', 'error', 'errorifexists'.") } - this } /** @@ -244,7 +250,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val session = df.sparkSession val useV1Sources = - session.sessionState.conf.userV1SourceWriterList.toLowerCase(Locale.ROOT).split(",") + session.sessionState.conf.useV1SourceWriterList.toLowerCase(Locale.ROOT).split(",") val lookupCls = DataSource.lookupDataSource(source, session.sessionState.conf) val cls = lookupCls.newInstance() match { case f: FileDataSourceV2 if useV1Sources.contains(f.shortName()) || @@ -259,36 +265,48 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val provider = cls.getConstructor().newInstance().asInstanceOf[TableProvider] val sessionOptions = DataSourceV2Utils.extractSessionConfigs( provider, session.sessionState.conf) - val checkFilesExistsOption = DataSourceOptions.CHECK_FILES_EXIST_KEY -> "false" + // TODO SPARK-27113: remove this option. + val checkFilesExistsOption = "check_files_exist" -> "false" val options = sessionOptions ++ extraOptions + checkFilesExistsOption - val dsOptions = new DataSourceOptions(options.asJava) + val dsOptions = new CaseInsensitiveStringMap(options.asJava) + + import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ provider.getTable(dsOptions) match { - case table: SupportsBatchWrite => - if (mode == SaveMode.Append) { - val relation = DataSourceV2Relation.create(table, options) + // TODO (SPARK-27815): To not break existing tests, here we treat file source as a special + // case, and pass the save mode to file source directly. This hack should be removed. + case table: FileTable => + val write = table.newWriteBuilder(dsOptions).asInstanceOf[FileWriteBuilder] + .mode(modeForDSV1) // should not change default mode for file source. + .withQueryId(UUID.randomUUID().toString) + .withInputDataSchema(df.logicalPlan.schema) + .buildForBatch() + // The returned `Write` can be null, which indicates that we can skip writing. + if (write != null) { runCommand(df.sparkSession, "save") { - AppendData.byName(relation, df.logicalPlan) - } - } else { - val writeBuilder = table.newWriteBuilder(dsOptions) - .withQueryId(UUID.randomUUID().toString) - .withInputDataSchema(df.logicalPlan.schema) - writeBuilder match { - case s: SupportsSaveMode => - val write = s.mode(mode).buildForBatch() - // It can only return null with `SupportsSaveMode`. We can clean it up after - // removing `SupportsSaveMode`. 
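As the updated scaladoc above spells out, an unset mode now resolves to `ErrorIfExists` on the v1 path and `Append` on the v2 path, and an explicit `Overwrite` against a v2 table is only planned when the table reports `TRUNCATE` or `OVERWRITE_BY_FILTER`. A hedged usage sketch, assuming an active `spark` session and a placeholder v2 format name `com.example.v2source`:

    import org.apache.spark.sql.SaveMode

    val df = spark.range(10).toDF("id")

    // v2 path with no explicit mode: treated as Append.
    df.write.format("com.example.v2source").save()

    // Overwrite is only accepted if the table supports TRUNCATE or
    // OVERWRITE_BY_FILTER; otherwise the new code raises an AnalysisException.
    df.write.format("com.example.v2source").mode(SaveMode.Overwrite).save()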
- if (write != null) { - runCommand(df.sparkSession, "save") { - WriteToDataSourceV2(write, df.logicalPlan) - } - } - - case _ => throw new AnalysisException( - s"data source ${table.name} does not support SaveMode $mode") + WriteToDataSourceV2(write, df.logicalPlan) } } + case table: SupportsWrite if table.supports(BATCH_WRITE) => + lazy val relation = DataSourceV2Relation.create(table, dsOptions) + modeForDSV2 match { + case SaveMode.Append => + runCommand(df.sparkSession, "save") { + AppendData.byName(relation, df.logicalPlan) + } + + case SaveMode.Overwrite if table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER) => + // truncate the table + runCommand(df.sparkSession, "save") { + OverwriteByExpression.byName(relation, df.logicalPlan, Literal(true)) + } + + case other => + throw new AnalysisException(s"TableProvider implementation $source cannot be " + + s"written with $other mode, please use Append or Overwrite " + + "modes instead.") + } + // Streaming also uses the data source V2 API. So it may be that the data source implements // v2, but has no v2 implementation for batch writes. In that case, we fall back to saving // as though it's a V1 source. @@ -306,7 +324,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { sparkSession = df.sparkSession, className = source, partitionColumns = partitioningColumns.getOrElse(Nil), - options = extraOptions.toMap).planForWriting(mode, df.logicalPlan) + options = extraOptions.toMap).planForWriting(modeForDSV1, df.logicalPlan) } } @@ -355,7 +373,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { table = UnresolvedRelation(tableIdent), partition = Map.empty[String, Option[String]], query = df.logicalPlan, - overwrite = mode == SaveMode.Overwrite, + overwrite = modeForDSV1 == SaveMode.Overwrite, ifPartitionNotExists = false) } } @@ -435,7 +453,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val tableIdentWithDB = tableIdent.copy(database = Some(db)) val tableName = tableIdentWithDB.unquotedString - (tableExists, mode) match { + (tableExists, modeForDSV1) match { case (true, SaveMode.Ignore) => // Do nothing @@ -490,7 +508,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { partitionColumnNames = partitioningColumns.getOrElse(Nil), bucketSpec = getBucketSpec) - runCommand(df.sparkSession, "saveAsTable")(CreateTable(tableDesc, mode, Some(df.logicalPlan))) + runCommand(df.sparkSession, "saveAsTable")( + CreateTable(tableDesc, modeForDSV1, Some(df.logicalPlan))) } /** @@ -696,13 +715,17 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { SQLExecution.withNewExecutionId(session, qe, Some(name))(qe.toRdd) } + private def modeForDSV1 = mode.getOrElse(SaveMode.ErrorIfExists) + + private def modeForDSV2 = mode.getOrElse(SaveMode.Append) + /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options /////////////////////////////////////////////////////////////////////////////////////// private var source: String = df.sparkSession.sessionState.conf.defaultDataSourceName - private var mode: SaveMode = SaveMode.ErrorIfExists + private var mode: Option[SaveMode] = None private val extraOptions = new scala.collection.mutable.HashMap[String, String] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index c5d14dfffd9b2..ff5ca2ac1111a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -21,6 +21,7 @@ import java.io.Closeable import java.util.concurrent.atomic.AtomicReference import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal @@ -31,6 +32,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.catalog.Catalog +import org.apache.spark.sql.catalog.v2.{CatalogPlugin, Catalogs} import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.encoders._ @@ -619,6 +621,12 @@ class SparkSession private( */ @transient lazy val catalog: Catalog = new CatalogImpl(self) + @transient private lazy val catalogs = new mutable.HashMap[String, CatalogPlugin]() + + private[sql] def catalog(name: String): CatalogPlugin = synchronized { + catalogs.getOrElseUpdate(name, Catalogs.load(name, sessionState.conf)) + } + /** * Returns the specified table/view as a `DataFrame`. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala index c90b254a6d121..41cebc247a186 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala @@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand} +import org.apache.spark.sql.execution.command.{DescribeCommandBase, ExecutedCommandExec, ShowTablesCommand} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -35,7 +35,7 @@ object HiveResult { * `SparkSQLDriver` for CLI applications. */ def hiveResultString(executedPlan: SparkPlan): Seq[String] = executedPlan match { - case ExecutedCommandExec(desc: DescribeTableCommand) => + case ExecutedCommandExec(_: DescribeCommandBase) => // If it is a describe command for a Hive table, we want to have the output format // be similar with Hive. executedPlan.executeCollectPublic().map { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 8deb55b00a9d3..ac61661e83e32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -370,127 +370,72 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { } /** - * Type to keep track of a table header: (identifier, isTemporary, ifNotExists, isExternal). + * Create a [[DescribeQueryCommand]] logical command. */ - type TableHeader = (TableIdentifier, Boolean, Boolean, Boolean) + override def visitDescribeQuery(ctx: DescribeQueryContext): LogicalPlan = withOrigin(ctx) { + DescribeQueryCommand(visitQueryToDesc(ctx.queryToDesc())) + } /** - * Validate a create table statement and return the [[TableIdentifier]]. 
- */ - override def visitCreateTableHeader( - ctx: CreateTableHeaderContext): TableHeader = withOrigin(ctx) { - val temporary = ctx.TEMPORARY != null - val ifNotExists = ctx.EXISTS != null - if (temporary && ifNotExists) { - operationNotAllowed("CREATE TEMPORARY TABLE ... IF NOT EXISTS", ctx) + * Converts a multi-part identifier to a TableIdentifier. + * + * If the multi-part identifier has too many parts, this will throw a ParseException. + */ + def tableIdentifier( + multipart: Seq[String], + command: String, + ctx: ParserRuleContext): TableIdentifier = { + multipart match { + case Seq(tableName) => + TableIdentifier(tableName) + case Seq(database, tableName) => + TableIdentifier(tableName, Some(database)) + case _ => + operationNotAllowed(s"$command does not support multi-part identifiers", ctx) } - (visitTableIdentifier(ctx.tableIdentifier), temporary, ifNotExists, ctx.EXTERNAL != null) } /** * Create a table, returning a [[CreateTable]] logical plan. * - * Expected format: - * {{{ - * CREATE [TEMPORARY] TABLE [IF NOT EXISTS] [db_name.]table_name - * USING table_provider - * create_table_clauses - * [[AS] select_statement]; + * This is used to produce CreateTempViewUsing from CREATE TEMPORARY TABLE. * - * create_table_clauses (order insensitive): - * [OPTIONS table_property_list] - * [PARTITIONED BY (col_name, col_name, ...)] - * [CLUSTERED BY (col_name, col_name, ...) - * [SORTED BY (col_name [ASC|DESC], ...)] - * INTO num_buckets BUCKETS - * ] - * [LOCATION path] - * [COMMENT table_comment] - * [TBLPROPERTIES (property_name=property_value, ...)] - * }}} + * TODO: Remove this. It is used because CreateTempViewUsing is not a Catalyst plan. + * Either move CreateTempViewUsing into catalyst as a parsed logical plan, or remove it because + * it is deprecated. */ override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) { - val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) - if (external) { - operationNotAllowed("CREATE EXTERNAL TABLE ... 
USING", ctx) - } + val (ident, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) - checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx) - checkDuplicateClauses(ctx.OPTIONS, "OPTIONS", ctx) - checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx) - checkDuplicateClauses(ctx.COMMENT, "COMMENT", ctx) - checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx) - checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx) - - val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty) - val provider = ctx.tableProvider.qualifiedName.getText - val schema = Option(ctx.colTypeList()).map(createSchema) - val partitionColumnNames = - Option(ctx.partitionColumnNames) - .map(visitIdentifierList(_).toArray) - .getOrElse(Array.empty[String]) - val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty) - val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec) - - val location = ctx.locationSpec.asScala.headOption.map(visitLocationSpec) - val storage = DataSource.buildStorageFormatFromOptions(options) - - if (location.isDefined && storage.locationUri.isDefined) { - throw new ParseException( - "LOCATION and 'path' in OPTIONS are both used to indicate the custom table path, " + - "you can only specify one of them.", ctx) - } - val customLocation = storage.locationUri.orElse(location.map(CatalogUtils.stringToURI)) - - val tableType = if (customLocation.isDefined) { - CatalogTableType.EXTERNAL + if (!temp || ctx.query != null) { + super.visitCreateTable(ctx) } else { - CatalogTableType.MANAGED - } - - val tableDesc = CatalogTable( - identifier = table, - tableType = tableType, - storage = storage.copy(locationUri = customLocation), - schema = schema.getOrElse(new StructType), - provider = Some(provider), - partitionColumnNames = partitionColumnNames, - bucketSpec = bucketSpec, - properties = properties, - comment = Option(ctx.comment).map(string)) - - // Determine the storage mode. - val mode = if (ifNotExists) SaveMode.Ignore else SaveMode.ErrorIfExists - - if (ctx.query != null) { - // Get the backing query. - val query = plan(ctx.query) - - if (temp) { - operationNotAllowed("CREATE TEMPORARY TABLE ... USING ... AS query", ctx) + if (external) { + operationNotAllowed("CREATE EXTERNAL TABLE ... USING", ctx) } - // Don't allow explicit specification of schema for CTAS - if (schema.nonEmpty) { - operationNotAllowed( - "Schema may not be specified in a Create Table As Select (CTAS) statement", - ctx) - } - CreateTable(tableDesc, mode, Some(query)) - } else { - if (temp) { - if (ifNotExists) { - operationNotAllowed("CREATE TEMPORARY TABLE IF NOT EXISTS", ctx) - } + checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx) + checkDuplicateClauses(ctx.OPTIONS, "OPTIONS", ctx) + checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx) + checkDuplicateClauses(ctx.COMMENT, "COMMENT", ctx) + checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx) + checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx) - logWarning(s"CREATE TEMPORARY TABLE ... USING ... is deprecated, please use " + - "CREATE TEMPORARY VIEW ... USING ... instead") + if (ifNotExists) { // Unlike CREATE TEMPORARY VIEW USING, CREATE TEMPORARY TABLE USING does not support // IF NOT EXISTS. Users are not allowed to replace the existing temp table. 
- CreateTempViewUsing(table, schema, replace = false, global = false, provider, options) - } else { - CreateTable(tableDesc, mode, None) + operationNotAllowed("CREATE TEMPORARY TABLE IF NOT EXISTS", ctx) } + + val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty) + val provider = ctx.tableProvider.qualifiedName.getText + val schema = Option(ctx.colTypeList()).map(createSchema) + + logWarning(s"CREATE TEMPORARY TABLE ... USING ... is deprecated, please use " + + "CREATE TEMPORARY VIEW ... USING ... instead") + + val table = tableIdentifier(ident, "CREATE TEMPORARY VIEW", ctx) + CreateTempViewUsing(table, schema, replace = false, global = false, provider, options) } } @@ -555,77 +500,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { "MSCK REPAIR TABLE") } - /** - * Convert a table property list into a key-value map. - * This should be called through [[visitPropertyKeyValues]] or [[visitPropertyKeys]]. - */ - override def visitTablePropertyList( - ctx: TablePropertyListContext): Map[String, String] = withOrigin(ctx) { - val properties = ctx.tableProperty.asScala.map { property => - val key = visitTablePropertyKey(property.key) - val value = visitTablePropertyValue(property.value) - key -> value - } - // Check for duplicate property names. - checkDuplicateKeys(properties, ctx) - properties.toMap - } - - /** - * Parse a key-value map from a [[TablePropertyListContext]], assuming all values are specified. - */ - private def visitPropertyKeyValues(ctx: TablePropertyListContext): Map[String, String] = { - val props = visitTablePropertyList(ctx) - val badKeys = props.collect { case (key, null) => key } - if (badKeys.nonEmpty) { - operationNotAllowed( - s"Values must be specified for key(s): ${badKeys.mkString("[", ",", "]")}", ctx) - } - props - } - - /** - * Parse a list of keys from a [[TablePropertyListContext]], assuming no values are specified. - */ - private def visitPropertyKeys(ctx: TablePropertyListContext): Seq[String] = { - val props = visitTablePropertyList(ctx) - val badKeys = props.filter { case (_, v) => v != null }.keys - if (badKeys.nonEmpty) { - operationNotAllowed( - s"Values should not be specified for key(s): ${badKeys.mkString("[", ",", "]")}", ctx) - } - props.keys.toSeq - } - - /** - * A table property key can either be String or a collection of dot separated elements. This - * function extracts the property key based on whether its a string literal or a table property - * identifier. - */ - override def visitTablePropertyKey(key: TablePropertyKeyContext): String = { - if (key.STRING != null) { - string(key.STRING) - } else { - key.getText - } - } - - /** - * A table property value can be String, Integer, Boolean or Decimal. This function extracts - * the property value based on whether its a string, integer, boolean or decimal literal. - */ - override def visitTablePropertyValue(value: TablePropertyValueContext): String = { - if (value == null) { - null - } else if (value.STRING != null) { - string(value.STRING) - } else if (value.booleanValue != null) { - value.getText.toLowerCase(Locale.ROOT) - } else { - value.getText - } - } - /** * Create a [[CreateDatabaseCommand]] command. * @@ -772,17 +646,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { ctx.TEMPORARY != null) } - /** - * Create a [[DropTableCommand]] command. 
- */ - override def visitDropTable(ctx: DropTableContext): LogicalPlan = withOrigin(ctx) { - DropTableCommand( - visitTableIdentifier(ctx.tableIdentifier), - ctx.EXISTS != null, - ctx.VIEW != null, - ctx.PURGE != null) - } - /** * Create a [[AlterTableRenameCommand]] command. * @@ -999,34 +862,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { newColumn = visitColType(ctx.colType)) } - /** - * Create location string. - */ - override def visitLocationSpec(ctx: LocationSpecContext): String = withOrigin(ctx) { - string(ctx.STRING) - } - - /** - * Create a [[BucketSpec]]. - */ - override def visitBucketSpec(ctx: BucketSpecContext): BucketSpec = withOrigin(ctx) { - BucketSpec( - ctx.INTEGER_VALUE.getText.toInt, - visitIdentifierList(ctx.identifierList), - Option(ctx.orderedIdentifierList) - .toSeq - .flatMap(_.orderedIdentifier.asScala) - .map { orderedIdCtx => - Option(orderedIdCtx.ordering).map(_.getText).foreach { dir => - if (dir.toLowerCase(Locale.ROOT) != "asc") { - operationNotAllowed(s"Column ordering must be ASC, was '$dir'", ctx) - } - } - - orderedIdCtx.identifier.getText - }) - } - /** * Convert a nested constants list into a sequence of string sequences. */ @@ -1122,7 +957,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * }}} */ override def visitCreateHiveTable(ctx: CreateHiveTableContext): LogicalPlan = withOrigin(ctx) { - val (name, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) + val (ident, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) // TODO: implement temporary tables if (temp) { throw new ParseException( @@ -1180,6 +1015,8 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { CatalogTableType.MANAGED } + val name = tableIdentifier(ident, "CREATE TABLE ... STORED AS ...", ctx) + // TODO support the sql text - have a proper location for this! 
val tableDesc = CatalogTable( identifier = name, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index edfa70403ad15..e72ddf13f1668 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.execution.python._ import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.sources.MemoryPlanV2 +import org.apache.spark.sql.execution.streaming.sources.MemoryPlan import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery} import org.apache.spark.sql.types.StructType @@ -557,9 +557,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case r: RunnableCommand => ExecutedCommandExec(r) :: Nil case MemoryPlan(sink, output) => - val encoder = RowEncoder(sink.schema) - LocalTableScanExec(output, sink.allData.map(r => encoder.toRow(r).copy())) :: Nil - case MemoryPlanV2(sink, output) => val encoder = RowEncoder(StructType.fromAttributes(output)) LocalTableScanExec(output, sink.allData.map(r => encoder.toRow(r).copy())) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 2bf6a58b55658..4b692aaeb1e63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -35,6 +35,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.util.{ByteBufferOutputStream, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 8dd484af6e908..6147d6fefd52a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -25,6 +25,7 @@ import org.apache.arrow.vector.complex._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.ArrowUtils object ArrowWriter { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index d24e66e583857..8b70e336c14bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -29,12 +29,12 @@ import org.apache.hadoop.fs.{FileContext, FsConstants, Path} import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, 
UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, UnresolvedAttribute, UnresolvedRelation} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTableType._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical.Histogram +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdentifier} import org.apache.spark.sql.execution.datasources.{DataSource, PartitioningUtils} import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat @@ -494,6 +494,34 @@ case class TruncateTableCommand( } } +abstract class DescribeCommandBase extends RunnableCommand { + override val output: Seq[Attribute] = Seq( + // Column names are based on Hive. + AttributeReference("col_name", StringType, nullable = false, + new MetadataBuilder().putString("comment", "name of the column").build())(), + AttributeReference("data_type", StringType, nullable = false, + new MetadataBuilder().putString("comment", "data type of the column").build())(), + AttributeReference("comment", StringType, nullable = true, + new MetadataBuilder().putString("comment", "comment of the column").build())() + ) + + protected def describeSchema( + schema: StructType, + buffer: ArrayBuffer[Row], + header: Boolean): Unit = { + if (header) { + append(buffer, s"# ${output.head.name}", output(1).name, output(2).name) + } + schema.foreach { column => + append(buffer, column.name, column.dataType.simpleString, column.getComment().orNull) + } + } + + protected def append( + buffer: ArrayBuffer[Row], column: String, dataType: String, comment: String): Unit = { + buffer += Row(column, dataType, comment) + } +} /** * Command that looks like * {{{ @@ -504,17 +532,7 @@ case class DescribeTableCommand( table: TableIdentifier, partitionSpec: TablePartitionSpec, isExtended: Boolean) - extends RunnableCommand { - - override val output: Seq[Attribute] = Seq( - // Column names are based on Hive. - AttributeReference("col_name", StringType, nullable = false, - new MetadataBuilder().putString("comment", "name of the column").build())(), - AttributeReference("data_type", StringType, nullable = false, - new MetadataBuilder().putString("comment", "data type of the column").build())(), - AttributeReference("comment", StringType, nullable = true, - new MetadataBuilder().putString("comment", "comment of the column").build())() - ) + extends DescribeCommandBase { override def run(sparkSession: SparkSession): Seq[Row] = { val result = new ArrayBuffer[Row] @@ -603,22 +621,31 @@ case class DescribeTableCommand( } table.storage.toLinkedHashMap.foreach(s => append(buffer, s._1, s._2, "")) } +} - private def describeSchema( - schema: StructType, - buffer: ArrayBuffer[Row], - header: Boolean): Unit = { - if (header) { - append(buffer, s"# ${output.head.name}", output(1).name, output(2).name) - } - schema.foreach { column => - append(buffer, column.name, column.dataType.simpleString, column.getComment().orNull) - } - } +/** + * Command that looks like + * {{{ + * DESCRIBE [QUERY] statement + * }}} + * + * Parameter 'statement' can be one of the following types : + * 1. SELECT statements + * 2. SELECT statements inside set operators (UNION, INTERSECT etc) + * 3. VALUES statement. + * 4. TABLE statement. Example : TABLE table_name + * 5. 
statements of the form 'FROM table SELECT *' + * + * TODO : support CTEs. + */ +case class DescribeQueryCommand(query: LogicalPlan) + extends DescribeCommandBase { - private def append( - buffer: ArrayBuffer[Row], column: String, dataType: String, comment: String): Unit = { - buffer += Row(column, dataType, comment) + override def run(sparkSession: SparkSession): Seq[Row] = { + val result = new ArrayBuffer[Row] + val queryExecution = sparkSession.sessionState.executePlan(query) + describeSchema(queryExecution.analyzed.schema, result, header = false) + result } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index b0548bc21156e..622ad3b559ebd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -348,7 +348,8 @@ case class DataSource( case (format: FileFormat, _) if FileStreamSink.hasMetadata( caseInsensitiveOptions.get("path").toSeq ++ paths, - sparkSession.sessionState.newHadoopConf()) => + sparkSession.sessionState.newHadoopConf(), + sparkSession.sessionState.conf) => val basePath = new Path((caseInsensitiveOptions.get("path").toSeq ++ paths).head) val fileCatalog = new MetadataLogFileIndex(sparkSession, basePath, userSpecifiedSchema) val dataSchema = userSpecifiedSchema.orElse { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala new file mode 100644 index 0000000000000..19881f69f158c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
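The `DescribeQueryCommand` added in tables.scala above is reachable through plain SQL; a small sketch, assuming an active `spark` session built with this patch:

    // Reports the schema (col_name, data_type, comment) of the query result
    // itself rather than of a stored table.
    spark.sql("DESCRIBE QUERY SELECT 1 AS id, 'a' AS tag").show()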
+ */ + +package org.apache.spark.sql.execution.datasources + +import java.util.Locale + +import scala.collection.mutable + +import org.apache.spark.sql.{AnalysisException, SaveMode} +import org.apache.spark.sql.catalog.v2.{CatalogPlugin, Identifier, LookupCatalog, TableCatalog} +import org.apache.spark.sql.catalog.v2.expressions.Transform +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.CastSupport +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTableType, CatalogUtils} +import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, CreateV2Table, DropTable, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.sql.{CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.command.DropTableCommand +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.TableProvider +import org.apache.spark.sql.types.StructType + +case class DataSourceResolution( + conf: SQLConf, + findCatalog: String => CatalogPlugin) + extends Rule[LogicalPlan] with CastSupport with LookupCatalog { + + import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._ + + override protected def lookupCatalog(name: String): CatalogPlugin = findCatalog(name) + + def defaultCatalog: Option[CatalogPlugin] = conf.defaultV2Catalog.map(findCatalog) + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case CreateTableStatement( + AsTableIdentifier(table), schema, partitionCols, bucketSpec, properties, + V1WriteProvider(provider), options, location, comment, ifNotExists) => + + val tableDesc = buildCatalogTable(table, schema, partitionCols, bucketSpec, properties, + provider, options, location, comment, ifNotExists) + val mode = if (ifNotExists) SaveMode.Ignore else SaveMode.ErrorIfExists + + CreateTable(tableDesc, mode, None) + + case create: CreateTableStatement => + // the provider was not a v1 source, convert to a v2 plan + val CatalogObjectIdentifier(maybeCatalog, identifier) = create.tableName + val catalog = maybeCatalog.orElse(defaultCatalog) + .getOrElse(throw new AnalysisException( + s"No catalog specified for table ${identifier.quoted} and no default catalog is set")) + .asTableCatalog + convertCreateTable(catalog, identifier, create) + + case CreateTableAsSelectStatement( + AsTableIdentifier(table), query, partitionCols, bucketSpec, properties, + V1WriteProvider(provider), options, location, comment, ifNotExists) => + + val tableDesc = buildCatalogTable(table, new StructType, partitionCols, bucketSpec, + properties, provider, options, location, comment, ifNotExists) + val mode = if (ifNotExists) SaveMode.Ignore else SaveMode.ErrorIfExists + + CreateTable(tableDesc, mode, Some(query)) + + case create: CreateTableAsSelectStatement => + // the provider was not a v1 source, convert to a v2 plan + val CatalogObjectIdentifier(maybeCatalog, identifier) = create.tableName + val catalog = maybeCatalog.orElse(defaultCatalog) + .getOrElse(throw new AnalysisException( + s"No catalog specified for table ${identifier.quoted} and no default catalog is set")) + .asTableCatalog + convertCTAS(catalog, identifier, create) + + case DropTableStatement(CatalogObjectIdentifier(Some(catalog), ident), ifExists, _) => + DropTable(catalog.asTableCatalog, ident, ifExists) + + case DropTableStatement(AsTableIdentifier(tableName), ifExists, purge) => + 
DropTableCommand(tableName, ifExists, isView = false, purge) + + case DropViewStatement(CatalogObjectIdentifier(Some(catalog), ident), _) => + throw new AnalysisException( + s"Can not specify catalog `${catalog.name}` for view $ident " + + s"because view support in catalog has not been implemented yet") + + case DropViewStatement(AsTableIdentifier(tableName), ifExists) => + DropTableCommand(tableName, ifExists, isView = true, purge = false) + } + + object V1WriteProvider { + private val v1WriteOverrideSet = + conf.useV1SourceWriterList.toLowerCase(Locale.ROOT).split(",").toSet + + def unapply(provider: String): Option[String] = { + if (v1WriteOverrideSet.contains(provider.toLowerCase(Locale.ROOT))) { + Some(provider) + } else { + lazy val providerClass = DataSource.lookupDataSource(provider, conf) + provider match { + case _ if classOf[TableProvider].isAssignableFrom(providerClass) => + None + case _ => + Some(provider) + } + } + } + } + + private def buildCatalogTable( + table: TableIdentifier, + schema: StructType, + partitioning: Seq[Transform], + bucketSpec: Option[BucketSpec], + properties: Map[String, String], + provider: String, + options: Map[String, String], + location: Option[String], + comment: Option[String], + ifNotExists: Boolean): CatalogTable = { + + val storage = DataSource.buildStorageFormatFromOptions(options) + if (location.isDefined && storage.locationUri.isDefined) { + throw new AnalysisException( + "LOCATION and 'path' in OPTIONS are both used to indicate the custom table path, " + + "you can only specify one of them.") + } + val customLocation = storage.locationUri.orElse(location.map(CatalogUtils.stringToURI)) + + val tableType = if (customLocation.isDefined) { + CatalogTableType.EXTERNAL + } else { + CatalogTableType.MANAGED + } + + CatalogTable( + identifier = table, + tableType = tableType, + storage = storage.copy(locationUri = customLocation), + schema = schema, + provider = Some(provider), + partitionColumnNames = partitioning.asPartitionColumns, + bucketSpec = bucketSpec, + properties = properties, + comment = comment) + } + + private def convertCTAS( + catalog: TableCatalog, + identifier: Identifier, + ctas: CreateTableAsSelectStatement): CreateTableAsSelect = { + // convert the bucket spec and add it as a transform + val partitioning = ctas.partitioning ++ ctas.bucketSpec.map(_.asTransform) + val properties = convertTableProperties( + ctas.properties, ctas.options, ctas.location, ctas.comment, ctas.provider) + + CreateTableAsSelect( + catalog, + identifier, + partitioning, + ctas.asSelect, + properties, + writeOptions = ctas.options.filterKeys(_ != "path"), + ignoreIfExists = ctas.ifNotExists) + } + + private def convertCreateTable( + catalog: TableCatalog, + identifier: Identifier, + create: CreateTableStatement): CreateV2Table = { + // convert the bucket spec and add it as a transform + val partitioning = create.partitioning ++ create.bucketSpec.map(_.asTransform) + val properties = convertTableProperties( + create.properties, create.options, create.location, create.comment, create.provider) + + CreateV2Table( + catalog, + identifier, + create.tableSchema, + partitioning, + properties, + ignoreIfExists = create.ifNotExists) + } + + private def convertTableProperties( + properties: Map[String, String], + options: Map[String, String], + location: Option[String], + comment: Option[String], + provider: String): Map[String, String] = { + if (options.contains("path") && location.isDefined) { + throw new AnalysisException( + "LOCATION and 'path' in OPTIONS are 
both used to indicate the custom table path, " + + "you can only specify one of them.") + } + + if ((options.contains("comment") || properties.contains("comment")) + && comment.isDefined) { + throw new AnalysisException( + "COMMENT and option/property 'comment' are both used to set the table comment, you can " + + "only specify one of them.") + } + + if (options.contains("provider") || properties.contains("provider")) { + throw new AnalysisException( + "USING and option/property 'provider' are both used to set the provider implementation, " + + "you can only specify one of them.") + } + + val filteredOptions = options.filterKeys(_ != "path") + + // create table properties from TBLPROPERTIES and OPTIONS clauses + val tableProperties = new mutable.HashMap[String, String]() + tableProperties ++= properties + tableProperties ++= filteredOptions + + // convert USING, LOCATION, and COMMENT clauses to table properties + tableProperties += ("provider" -> provider) + comment.map(text => tableProperties += ("comment" -> text)) + location.orElse(options.get("path")).map(loc => tableProperties += ("location" -> loc)) + + tableProperties.toMap + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index b5cf8c9515bfb..b73dc30d6f23c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -426,6 +426,22 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with } object DataSourceStrategy { + /** + * The attribute name of predicate could be different than the one in schema in case of + * case insensitive, we should change them to match the one in schema, so we do not need to + * worry about case sensitivity anymore. + */ + protected[sql] def normalizeFilters( + filters: Seq[Expression], + attributes: Seq[AttributeReference]): Seq[Expression] = { + filters.filterNot(SubqueryExpression.hasSubquery).map { e => + e transform { + case a: AttributeReference => + a.withName(attributes.find(_.semanticEquals(a)).get.name) + } + } + } + /** * Tries to translate a Catalyst [[Expression]] into data source [[Filter]]. 
* @@ -513,6 +529,12 @@ object DataSourceStrategy { case expressions.Contains(a: Attribute, Literal(v: UTF8String, StringType)) => Some(sources.StringContains(a.name, v.toString)) + case expressions.Literal(true, BooleanType) => + Some(sources.AlwaysTrue) + + case expressions.Literal(false, BooleanType) => + Some(sources.AlwaysFalse) + case _ => None } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallbackOrcDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallbackOrcDataSourceV2.scala index 254c09001f7ec..7c72495548e3a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallbackOrcDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallbackOrcDataSourceV2.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources +import scala.collection.JavaConverters._ + import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule @@ -33,10 +35,15 @@ import org.apache.spark.sql.execution.datasources.v2.orc.OrcTable */ class FallbackOrcDataSourceV2(sparkSession: SparkSession) extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case i @ InsertIntoTable(d @DataSourceV2Relation(table: OrcTable, _, _), _, _, _, _) => + case i @ InsertIntoTable(d @ DataSourceV2Relation(table: OrcTable, _, _), _, _, _, _) => val v1FileFormat = new OrcFileFormat - val relation = HadoopFsRelation(table.getFileIndex, table.getFileIndex.partitionSchema, - table.schema(), None, v1FileFormat, d.options)(sparkSession) + val relation = HadoopFsRelation( + table.fileIndex, + table.fileIndex.partitionSchema, + table.schema(), + None, + v1FileFormat, + d.options.asScala.toMap)(sparkSession) i.copy(table = LogicalRelation(relation)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 62ab5c80d47cf..970cbda6355e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -147,15 +147,7 @@ object FileSourceStrategy extends Strategy with Logging { // - filters that need to be evaluated again after the scan val filterSet = ExpressionSet(filters) - // The attribute name of predicate could be different than the one in schema in case of - // case insensitive, we should change them to match the one in schema, so we do not need to - // worry about case sensitivity anymore. 
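The effect of the extracted `normalizeFilters` helper above is easiest to see at the DataFrame level: under case-insensitive analysis a predicate may spell a column differently from the schema, and the filter that reaches the source is rewritten to the schema's spelling. A sketch, assuming an active `spark` session:

    // Case-insensitive analysis is the default.
    spark.conf.set("spark.sql.caseSensitive", "false")
    val df = spark.range(5).toDF("id")
    // The predicate says "ID", but after normalization the pushed-down filter
    // refers to the schema column "id", so push-down is unaffected by casing.
    df.filter("ID > 2").explain(true)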
- val normalizedFilters = filters.filterNot(SubqueryExpression.hasSubquery).map { e => - e transform { - case a: AttributeReference => - a.withName(l.output.find(_.semanticEquals(a)).get.name) - } - } + val normalizedFilters = DataSourceStrategy.normalizeFilters(filters, l.output) val partitionColumns = l.resolve( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala index 452ebbbeb99c8..2d90fd594fa7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala @@ -17,43 +17,42 @@ package org.apache.spark.sql.execution.datasources.noop -import org.apache.spark.sql.SaveMode +import java.util + +import scala.collection.JavaConverters._ + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} -import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap /** * This is no-op datasource. It does not do anything besides consuming its input. * This can be useful for benchmarking or to cache data without any additional overhead. */ -class NoopDataSource - extends DataSourceV2 - with TableProvider - with DataSourceRegister - with StreamingWriteSupportProvider { - +class NoopDataSource extends TableProvider with DataSourceRegister { override def shortName(): String = "noop" - override def getTable(options: DataSourceOptions): Table = NoopTable - override def createStreamingWriteSupport( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceOptions): StreamingWriteSupport = NoopStreamingWriteSupport + override def getTable(options: CaseInsensitiveStringMap): Table = NoopTable } -private[noop] object NoopTable extends Table with SupportsBatchWrite { - override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = NoopWriteBuilder +private[noop] object NoopTable extends Table with SupportsWrite { + override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = NoopWriteBuilder override def name(): String = "noop-table" override def schema(): StructType = new StructType() + override def capabilities(): util.Set[TableCapability] = Set( + TableCapability.BATCH_WRITE, + TableCapability.TRUNCATE, + TableCapability.ACCEPT_ANY_SCHEMA, + TableCapability.STREAMING_WRITE).asJava } -private[noop] object NoopWriteBuilder extends WriteBuilder with SupportsSaveMode { +private[noop] object NoopWriteBuilder extends WriteBuilder with SupportsTruncate { + override def truncate(): WriteBuilder = this override def buildForBatch(): BatchWrite = NoopBatchWrite - override def mode(mode: SaveMode): WriteBuilder = this + override def buildForStreaming(): StreamingWrite = NoopStreamingWrite } private[noop] object NoopBatchWrite extends BatchWrite { @@ -72,7 +71,7 @@ private[noop] object NoopWriter extends DataWriter[InternalRow] { override def abort(): Unit = {} } -private[noop] object NoopStreamingWriteSupport extends StreamingWriteSupport { +private[noop] 
object NoopStreamingWrite extends StreamingWrite { override def createStreamingWriterFactory(): StreamingDataWriterFactory = NoopStreamingDataWriterFactory override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} @@ -85,4 +84,3 @@ private[noop] object NoopStreamingDataWriterFactory extends StreamingDataWriterF taskId: Long, epochId: Long): DataWriter[InternalRow] = NoopWriter } - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala new file mode 100644 index 0000000000000..f35758bf08c67 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import scala.collection.JavaConverters._ + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalog.v2.{Identifier, TableCatalog} +import org.apache.spark.sql.catalog.v2.expressions.Transform +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.execution.LeafExecNode +import org.apache.spark.sql.types.StructType + +case class CreateTableExec( + catalog: TableCatalog, + identifier: Identifier, + tableSchema: StructType, + partitioning: Seq[Transform], + tableProperties: Map[String, String], + ignoreIfExists: Boolean) extends LeafExecNode { + import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._ + + override protected def doExecute(): RDD[InternalRow] = { + if (!catalog.tableExists(identifier)) { + try { + catalog.createTable(identifier, tableSchema, partitioning.toArray, tableProperties.asJava) + } catch { + case _: TableAlreadyExistsException if ignoreIfExists => + logWarning(s"Table ${identifier.quoted} was created concurrently. Ignoring.") + } + } else if (!ignoreIfExists) { + throw new TableAlreadyExistsException(identifier) + } + + sqlContext.sparkContext.parallelize(Seq.empty, 1) + } + + override def output: Seq[Attribute] = Seq.empty +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala new file mode 100644 index 0000000000000..eed69cdc8cac6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.sources.v2.{SupportsRead, SupportsWrite, Table, TableCapability} + +object DataSourceV2Implicits { + implicit class TableHelper(table: Table) { + def asReadable: SupportsRead = { + table match { + case support: SupportsRead => + support + case _ => + throw new AnalysisException(s"Table does not support reads: ${table.name}") + } + } + + def asWritable: SupportsWrite = { + table match { + case support: SupportsWrite => + support + case _ => + throw new AnalysisException(s"Table does not support writes: ${table.name}") + } + } + + def supports(capability: TableCapability): Boolean = table.capabilities.contains(capability) + + def supportsAny(capabilities: TableCapability*): Boolean = capabilities.exists(supports) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 47cf26dc9481e..fc919439d9224 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -17,20 +17,15 @@ package org.apache.spark.sql.execution.datasources.v2 -import java.util.UUID - -import scala.collection.JavaConverters._ - -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.{Statistics => V2Statistics, _} import org.apache.spark.sql.sources.v2.reader.streaming.{Offset, SparkDataStream} import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap /** * A logical plan representing a data source v2 table. 
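The `DataSourceV2Implicits` helpers introduced above replace the earlier ad-hoc pattern matches on `SupportsBatchRead`/`SupportsBatchWrite`; a hedged sketch of how internal callers (like the `DataFrameWriter` changes in this patch) combine them, with `table` assumed to be an already-resolved v2 `Table`:

    import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
    import org.apache.spark.sql.sources.v2.Table
    import org.apache.spark.sql.sources.v2.TableCapability._

    def canAppend(table: Table): Boolean =
      table.supports(BATCH_WRITE)

    def canOverwrite(table: Table): Boolean =
      table.supports(BATCH_WRITE) && table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER)

    // table.asWritable narrows to SupportsWrite, or fails with
    // "Table does not support writes: <name>" if the table is read-only.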
@@ -42,29 +37,21 @@ import org.apache.spark.sql.types.StructType case class DataSourceV2Relation( table: Table, output: Seq[AttributeReference], - options: Map[String, String]) + options: CaseInsensitiveStringMap) extends LeafNode with MultiInstanceRelation with NamedRelation { + import DataSourceV2Implicits._ + override def name: String = table.name() + override def skipSchemaResolution: Boolean = table.supports(TableCapability.ACCEPT_ANY_SCHEMA) + override def simpleString(maxFields: Int): String = { s"RelationV2${truncatedString(output, "[", ", ", "]", maxFields)} $name" } - def newScanBuilder(): ScanBuilder = table match { - case s: SupportsBatchRead => - val dsOptions = new DataSourceOptions(options.asJava) - s.newScanBuilder(dsOptions) - case _ => throw new AnalysisException(s"Table is not readable: ${table.name()}") - } - - def newWriteBuilder(schema: StructType): WriteBuilder = table match { - case s: SupportsBatchWrite => - val dsOptions = new DataSourceOptions(options.asJava) - s.newWriteBuilder(dsOptions) - .withQueryId(UUID.randomUUID().toString) - .withInputDataSchema(schema) - case _ => throw new AnalysisException(s"Table is not writable: ${table.name()}") + def newScanBuilder(): ScanBuilder = { + table.asReadable.newScanBuilder(options) } override def computeStats(): Statistics = { @@ -72,7 +59,7 @@ case class DataSourceV2Relation( scan match { case r: SupportsReportStatistics => val statistics = r.estimateStatistics() - Statistics(sizeInBytes = statistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) + DataSourceV2Relation.transformV2Stats(statistics, None, conf.defaultSizeInBytes) case _ => Statistics(sizeInBytes = conf.defaultSizeInBytes) } @@ -105,15 +92,32 @@ case class StreamingDataSourceV2Relation( override def computeStats(): Statistics = scan match { case r: SupportsReportStatistics => val statistics = r.estimateStatistics() - Statistics(sizeInBytes = statistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) + DataSourceV2Relation.transformV2Stats(statistics, None, conf.defaultSizeInBytes) case _ => Statistics(sizeInBytes = conf.defaultSizeInBytes) } } object DataSourceV2Relation { - def create(table: Table, options: Map[String, String]): DataSourceV2Relation = { + def create(table: Table, options: CaseInsensitiveStringMap): DataSourceV2Relation = { val output = table.schema().toAttributes DataSourceV2Relation(table, output, options) } + + /** + * This is used to transform data source v2 statistics to logical.Statistics. 
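A hedged sketch of the conversion `transformV2Stats` performs on the `OptionalLong`-based v2 statistics, with illustrative values only:

    import java.util.OptionalLong
    import org.apache.spark.sql.sources.v2.reader.{Statistics => V2Statistics}

    // The source knows its size but not its row count.
    val v2Stats = new V2Statistics {
      override def sizeInBytes(): OptionalLong = OptionalLong.of(1024L)
      override def numRows(): OptionalLong = OptionalLong.empty()
    }
    // transformV2Stats(v2Stats, defaultRowCount = None, defaultSizeInBytes = 1L << 20)
    // then yields Statistics(sizeInBytes = 1024, rowCount = None): the defaults are
    // used only for values the source does not report.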
+ */ + def transformV2Stats( + v2Statistics: V2Statistics, + defaultRowCount: Option[BigInt], + defaultSizeInBytes: Long): Statistics = { + val numRows: Option[BigInt] = if (v2Statistics.numRows().isPresent) { + Some(v2Statistics.numRows().getAsLong) + } else { + defaultRowCount + } + Statistics( + sizeInBytes = v2Statistics.sizeInBytes().orElse(defaultSizeInBytes), + rowCount = numRows) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 40ac5cf402987..9889fd6731565 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -17,20 +17,22 @@ package org.apache.spark.sql.execution.datasources.v2 +import scala.collection.JavaConverters._ import scala.collection.mutable -import org.apache.spark.sql.{sources, AnalysisException, SaveMode, Strategy} -import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression} +import org.apache.spark.sql.{AnalysisException, Strategy} +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression, PredicateHelper, SubqueryExpression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, Repartition} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, CreateV2Table, DropTable, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Repartition} import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec} +import org.apache.spark.sql.sources import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream} -import org.apache.spark.sql.sources.v2.writer.SupportsSaveMode +import org.apache.spark.sql.util.CaseInsensitiveStringMap -object DataSourceV2Strategy extends Strategy { +object DataSourceV2Strategy extends Strategy with PredicateHelper { /** * Pushes down filters to the data source reader @@ -100,14 +102,22 @@ object DataSourceV2Strategy extends Strategy { } } + import DataSourceV2Implicits._ override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => val scanBuilder = relation.newScanBuilder() + + val (withSubquery, withoutSubquery) = filters.partition(SubqueryExpression.hasSubquery) + val normalizedFilters = DataSourceStrategy.normalizeFilters( + withoutSubquery, relation.output) + // `pushedFilters` will be pushed down and evaluated in the underlying data sources. // `postScanFilters` need to be evaluated after the scan. // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter. 
- val (pushedFilters, postScanFilters) = pushFilters(scanBuilder, filters) + val (pushedFilters, postScanFiltersWithoutSubquery) = + pushFilters(scanBuilder, normalizedFilters) + val postScanFilters = postScanFiltersWithoutSubquery ++ withSubquery val (scan, output) = pruneColumns(scanBuilder, relation, project ++ postScanFilters) logInfo( s""" @@ -142,15 +152,29 @@ object DataSourceV2Strategy { case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil + case CreateV2Table(catalog, ident, schema, parts, props, ifNotExists) => + CreateTableExec(catalog, ident, schema, parts, props, ifNotExists) :: Nil + + case CreateTableAsSelect(catalog, ident, parts, query, props, options, ifNotExists) => + val writeOptions = new CaseInsensitiveStringMap(options.asJava) + CreateTableAsSelectExec( + catalog, ident, parts, planLater(query), props, writeOptions, ifNotExists) :: Nil + case AppendData(r: DataSourceV2Relation, query, _) => - val writeBuilder = r.newWriteBuilder(query.schema) - writeBuilder match { - case s: SupportsSaveMode => - val write = s.mode(SaveMode.Append).buildForBatch() - assert(write != null) - WriteToDataSourceV2Exec(write, planLater(query)) :: Nil - case _ => throw new AnalysisException(s"data source ${r.name} does not support SaveMode") - } + AppendDataExec(r.table.asWritable, r.options, planLater(query)) :: Nil + + case OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, _) => + // Fail if any filter cannot be converted. Correctness depends on removing all matching data. + val filters = splitConjunctivePredicates(deleteExpr).map { + filter => DataSourceStrategy.translateFilter(filter).getOrElse( + throw new AnalysisException(s"Cannot translate expression to source filter: $filter")) + }.toArray + + OverwriteByExpressionExec( + r.table.asWritable, filters, r.options, planLater(query)) :: Nil + + case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, _) => + OverwritePartitionsDynamicExec(r.table.asWritable, r.options, planLater(query)) :: Nil case WriteToContinuousDataSource(writer, query) => WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil @@ -167,6 +191,9 @@ object DataSourceV2Strategy { Nil } + case DropTable(catalog, ident, ifExists) => + DropTableExec(catalog, ident, ifExists) :: Nil + + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala deleted file mode 100644 index f11703c8a2773..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.v2 - -import org.apache.commons.lang3.StringUtils - -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} -import org.apache.spark.sql.catalyst.util.truncatedString -import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.DataSourceV2 -import org.apache.spark.util.Utils - -/** - * A trait that can be used by data source v2 related query plans(both logical and physical), to - * provide a string format of the data source information for explain. - */ -trait DataSourceV2StringFormat { - - /** - * The instance of this data source implementation. Note that we only consider its class in - * equals/hashCode, not the instance itself. - */ - def source: DataSourceV2 - - /** - * The output of the data source reader, w.r.t. column pruning. - */ - def output: Seq[Attribute] - - /** - * The options for this data source reader. - */ - def options: Map[String, String] - - /** - * The filters which have been pushed to the data source. - */ - def pushedFilters: Seq[Expression] - - private def sourceName: String = source match { - case registered: DataSourceRegister => registered.shortName() - // source.getClass.getSimpleName can cause Malformed class name error, - // call safer `Utils.getSimpleName` instead - case _ => Utils.getSimpleName(source.getClass) - } - - def metadataString(maxFields: Int): String = { - val entries = scala.collection.mutable.ArrayBuffer.empty[(String, String)] - - if (pushedFilters.nonEmpty) { - entries += "Filters" -> pushedFilters.mkString("[", ", ", "]") - } - - // TODO: we should only display some standard options like path, table, etc. - if (options.nonEmpty) { - entries += "Options" -> Utils.redact(options).map { - case (k, v) => s"$k=$v" - }.mkString("[", ",", "]") - } - - val outputStr = truncatedString(output, "[", ", ", "]", maxFields) - - val entriesStr = if (entries.nonEmpty) { - truncatedString(entries.map { - case (key, value) => key + ": " + StringUtils.abbreviate(value, 100) - }, " (", ", ", ")", maxFields) - } else { - "" - } - - s"$sourceName$outputStr$entriesStr" - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala index e9cc3991155c4..30897d86f8179 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala @@ -21,8 +21,7 @@ import java.util.regex.Pattern import org.apache.spark.internal.Logging import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceV2, SessionConfigSupport} +import org.apache.spark.sql.sources.v2.{SessionConfigSupport, TableProvider} private[sql] object DataSourceV2Utils extends Logging { @@ -34,34 +33,28 @@ private[sql] object DataSourceV2Utils extends Logging { * `spark.datasource.$keyPrefix`. A session config `spark.datasource.$keyPrefix.xxx -> yyy` will * be transformed into `xxx -> yyy`. * - * @param ds a [[DataSourceV2]] object + * @param source a [[TableProvider]] object * @param conf the session conf * @return an immutable map that contains all the extracted and transformed k/v pairs. 
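To make the key-prefix rewrite concrete, here is a stand-alone sketch that mirrors what extractSessionConfigs below does (the "myds" prefix and the conf entries are invented for illustration):

import java.util.regex.Pattern

val keyPrefix = "myds"
val pattern = Pattern.compile(s"^spark\\.datasource\\.$keyPrefix\\.(.+)")
val sessionConfs = Map(
  "spark.datasource.myds.host" -> "localhost",
  "spark.sql.shuffle.partitions" -> "200")

// Only the prefixed key survives, with the prefix stripped off.
val extracted = sessionConfs.flatMap { case (key, value) =>
  val m = pattern.matcher(key)
  if (m.matches() && m.groupCount() > 0) Seq(m.group(1) -> value) else Seq.empty
}
// extracted == Map("host" -> "localhost")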
*/ - def extractSessionConfigs(ds: DataSourceV2, conf: SQLConf): Map[String, String] = ds match { - case cs: SessionConfigSupport => - val keyPrefix = cs.keyPrefix() - require(keyPrefix != null, "The data source config key prefix can't be null.") - - val pattern = Pattern.compile(s"^spark\\.datasource\\.$keyPrefix\\.(.+)") - - conf.getAllConfs.flatMap { case (key, value) => - val m = pattern.matcher(key) - if (m.matches() && m.groupCount() > 0) { - Seq((m.group(1), value)) - } else { - Seq.empty + def extractSessionConfigs(source: TableProvider, conf: SQLConf): Map[String, String] = { + source match { + case cs: SessionConfigSupport => + val keyPrefix = cs.keyPrefix() + require(keyPrefix != null, "The data source config key prefix can't be null.") + + val pattern = Pattern.compile(s"^spark\\.datasource\\.$keyPrefix\\.(.+)") + + conf.getAllConfs.flatMap { case (key, value) => + val m = pattern.matcher(key) + if (m.matches() && m.groupCount() > 0) { + Seq((m.group(1), value)) + } else { + Seq.empty + } } - } - - case _ => Map.empty - } - def failForUserSpecifiedSchema[T](ds: DataSourceV2): T = { - val name = ds match { - case register: DataSourceRegister => register.shortName() - case _ => ds.getClass.getName + case _ => Map.empty } - throw new UnsupportedOperationException(name + " source does not support user-specified schema") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropTableExec.scala new file mode 100644 index 0000000000000..d325e0205f9d8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropTableExec.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalog.v2.{Identifier, TableCatalog} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.execution.LeafExecNode + +/** + * Physical plan node for dropping a table. 
+ */ +case class DropTableExec(catalog: TableCatalog, ident: Identifier, ifExists: Boolean) + extends LeafExecNode { + + override def doExecute(): RDD[InternalRow] = { + if (catalog.tableExists(ident)) { + catalog.dropTable(ident) + } else if (!ifExists) { + throw new NoSuchTableException(ident) + } + + sqlContext.sparkContext.parallelize(Seq.empty, 1) + } + + override def output: Seq[Attribute] = Seq.empty +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala index a0c932cbb0e09..e9c7a1bb749db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala @@ -16,13 +16,13 @@ */ package org.apache.spark.sql.execution.datasources.v2 -import scala.collection.JavaConverters._ +import com.fasterxml.jackson.databind.ObjectMapper -import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, SupportsBatchRead, TableProvider} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.sources.v2.TableProvider +import org.apache.spark.sql.util.CaseInsensitiveStringMap /** * A base interface for data source v2 implementations of the built-in file-based data sources. @@ -39,16 +39,12 @@ trait FileDataSourceV2 extends TableProvider with DataSourceRegister { lazy val sparkSession = SparkSession.active - def getFileIndex( - options: DataSourceOptions, - userSpecifiedSchema: Option[StructType]): PartitioningAwareFileIndex = { - val filePaths = options.paths() - val hadoopConf = - sparkSession.sessionState.newHadoopConfWithOptions(options.asMap().asScala.toMap) - val rootPathsSpecified = DataSource.checkAndGlobPathIfNecessary(filePaths, hadoopConf, - checkEmptyGlobPath = true, checkFilesExist = options.checkFilesExist()) - val fileStatusCache = FileStatusCache.getOrCreate(sparkSession) - new InMemoryFileIndex(sparkSession, rootPathsSpecified, - options.asMap().asScala.toMap, userSpecifiedSchema, fileStatusCache) + protected def getPaths(map: CaseInsensitiveStringMap): Seq[String] = { + val objectMapper = new ObjectMapper() + Option(map.get("paths")).map { pathStr => + objectMapper.readValue(pathStr, classOf[Array[String]]).toSeq + }.getOrElse { + Option(map.get("path")).toSeq + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala index 3615b15be6fd5..bdd6a48df20ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala @@ -18,15 +18,16 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.hadoop.fs.Path -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.execution.PartitionedFileUtil import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, Scan} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, StructType} 
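A rough illustration of the path handling in the getPaths helper above (the map contents are invented, and the snippet mirrors the protected method rather than calling it): "paths" carries a JSON-encoded string array and takes precedence, while a bare "path" entry yields a single-element list.

import scala.collection.JavaConverters._
import com.fasterxml.jackson.databind.ObjectMapper
import org.apache.spark.sql.util.CaseInsensitiveStringMap

val mapper = new ObjectMapper()
def resolvePaths(map: CaseInsensitiveStringMap): Seq[String] =
  Option(map.get("paths"))
    .map(str => mapper.readValue(str, classOf[Array[String]]).toSeq)
    .getOrElse(Option(map.get("path")).toSeq)

resolvePaths(new CaseInsensitiveStringMap(Map("paths" -> """["/data/a","/data/b"]""").asJava))
// Seq("/data/a", "/data/b")
resolvePaths(new CaseInsensitiveStringMap(Map("path" -> "/data/a").asJava))
// Seq("/data/a")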
abstract class FileScan( sparkSession: SparkSession, - fileIndex: PartitioningAwareFileIndex) extends Scan with Batch { + fileIndex: PartitioningAwareFileIndex, + readSchema: StructType) extends Scan with Batch { /** * Returns whether a file with `path` could be split or not. */ @@ -34,6 +35,22 @@ abstract class FileScan( false } + /** + * Returns whether this format supports the given [[DataType]] in write path. + * By default all data types are supported. + */ + def supportsDataType(dataType: DataType): Boolean = true + + /** + * The string that represents the format that this data source provider uses. This is + * overridden by children to provide a nice alias for the data source. For example: + * + * {{{ + * override def formatName(): String = "ORC" + * }}} + */ + def formatName: String + protected def partitions: Seq[FilePartition] = { val selectedPartitions = fileIndex.listFiles(Seq.empty, Seq.empty) val maxSplitBytes = FilePartition.maxSplitBytes(sparkSession, selectedPartitions) @@ -57,5 +74,13 @@ abstract class FileScan( partitions.toArray } - override def toBatch: Batch = this + override def toBatch: Batch = { + readSchema.foreach { field => + if (!supportsDataType(field.dataType)) { + throw new AnalysisException( + s"$formatName data source does not support ${field.dataType.catalogString} data type.") + } + } + this + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala index 0dbef145f7326..9cf292782ffe0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala @@ -16,19 +16,41 @@ */ package org.apache.spark.sql.execution.datasources.v2 +import java.util + +import scala.collection.JavaConverters._ + import org.apache.hadoop.fs.FileStatus import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.catalog.v2.expressions.Transform import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.sources.v2.{SupportsBatchRead, SupportsBatchWrite, Table} +import org.apache.spark.sql.sources.v2.{SupportsRead, SupportsWrite, Table, TableCapability} +import org.apache.spark.sql.sources.v2.TableCapability._ import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap abstract class FileTable( sparkSession: SparkSession, - fileIndex: PartitioningAwareFileIndex, + options: CaseInsensitiveStringMap, + paths: Seq[String], userSpecifiedSchema: Option[StructType]) - extends Table with SupportsBatchRead with SupportsBatchWrite { - def getFileIndex: PartitioningAwareFileIndex = this.fileIndex + extends Table with SupportsRead with SupportsWrite { + + import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._ + + lazy val fileIndex: PartitioningAwareFileIndex = { + val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap + // Hadoop Configurations are case sensitive. + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + // This is an internal config so must be present. 
+ val checkFilesExist = options.get("check_files_exist").toBoolean + val rootPathsSpecified = DataSource.checkAndGlobPathIfNecessary(paths, hadoopConf, + checkEmptyGlobPath = true, checkFilesExist = checkFilesExist) + val fileStatusCache = FileStatusCache.getOrCreate(sparkSession) + new InMemoryFileIndex( + sparkSession, rootPathsSpecified, caseSensitiveMap, userSpecifiedSchema, fileStatusCache) + } lazy val dataSchema: StructType = userSpecifiedSchema.orElse { inferSchema(fileIndex.allFiles()) @@ -43,6 +65,12 @@ abstract class FileTable( fileIndex.partitionSchema, caseSensitive)._1 } + override def partitioning: Array[Transform] = fileIndex.partitionSchema.asTransforms + + override def properties: util.Map[String, String] = options.asCaseSensitiveMap + + override def capabilities: java.util.Set[TableCapability] = FileTable.CAPABILITIES + /** * When possible, this method should return the schema of the given `files`. When the format * does not support inference, or no valid files are given should return None. In these cases @@ -50,3 +78,7 @@ abstract class FileTable( */ def inferSchema(files: Seq[FileStatus]): Option[StructType] } + +object FileTable { + private val CAPABILITIES = Set(BATCH_READ, BATCH_WRITE, TRUNCATE).asJava +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala index ce9b52f29d7bd..5375d965d1eff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.execution.datasources.v2 +import java.io.IOException import java.util.UUID import scala.collection.JavaConverters._ @@ -32,13 +33,16 @@ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, DataSource, OutputWriterFactory, WriteJobDescription} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.writer.{BatchWrite, SupportsSaveMode, WriteBuilder} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.sources.v2.writer.{BatchWrite, WriteBuilder} +import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration -abstract class FileWriteBuilder(options: DataSourceOptions) - extends WriteBuilder with SupportsSaveMode { +abstract class FileWriteBuilder( + options: CaseInsensitiveStringMap, + paths: Seq[String], + _formatName: String, + supportsDataType: DataType => Boolean) extends WriteBuilder { private var schema: StructType = _ private var queryId: String = _ private var mode: SaveMode = _ @@ -53,25 +57,25 @@ abstract class FileWriteBuilder(options: DataSourceOptions) this } - override def mode(mode: SaveMode): WriteBuilder = { + def mode(mode: SaveMode): WriteBuilder = { this.mode = mode this } override def buildForBatch(): BatchWrite = { validateInputs() - val pathName = options.paths().head - val path = new Path(pathName) + val path = new Path(paths.head) val sparkSession = SparkSession.active - val optionsAsScala = options.asMap().asScala.toMap - val hadoopConf = 
sparkSession.sessionState.newHadoopConfWithOptions(optionsAsScala) + val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap + // Hadoop Configurations are case sensitive. + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) val job = getJobInstance(hadoopConf, path) val committer = FileCommitProtocol.instantiate( sparkSession.sessionState.conf.fileCommitProtocolClass, jobId = java.util.UUID.randomUUID().toString, - outputPath = pathName) + outputPath = paths.head) lazy val description = - createWriteJobDescription(sparkSession, hadoopConf, job, pathName, optionsAsScala) + createWriteJobDescription(sparkSession, hadoopConf, job, paths.head, options.asScala.toMap) val fs = path.getFileSystem(hadoopConf) mode match { @@ -83,7 +87,9 @@ abstract class FileWriteBuilder(options: DataSourceOptions) null case SaveMode.Overwrite => - committer.deleteWithJob(fs, path, true) + if (fs.exists(path) && !committer.deleteWithJob(fs, path, true)) { + throw new IOException(s"Unable to clear directory $path prior to writing to it") + } committer.setupJob(job) new FileBatchWrite(job, description, committer) @@ -104,12 +110,35 @@ abstract class FileWriteBuilder(options: DataSourceOptions) options: Map[String, String], dataSchema: StructType): OutputWriterFactory + /** + * Returns whether this format supports the given [[DataType]] in write path. + * By default all data types are supported. + */ + def supportsDataType(dataType: DataType): Boolean = true + + /** + * The string that represents the format that this data source provider uses. This is + * overridden by children to provide a nice alias for the data source. For example: + * + * {{{ + * override def formatName(): String = "ORC" + * }}} + */ + def formatName: String + private def validateInputs(): Unit = { assert(schema != null, "Missing input data schema") assert(queryId != null, "Missing query ID") assert(mode != null, "Missing save mode") - assert(options.paths().length == 1) + assert(paths.length == 1) DataSource.validateSchema(schema) + schema.foreach { field => + if (!supportsDataType.apply(field.dataType)) { + throw new AnalysisException( + s"$formatName data source does not support ${field.dataType.catalogString}" + + s" data type.") + } + } } private def getJobInstance(hadoopConf: Configuration, path: Path): Job = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2StreamingScanSupportCheck.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2StreamingScanSupportCheck.scala new file mode 100644 index 0000000000000..c029acc0bb2df --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2StreamingScanSupportCheck.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} +import org.apache.spark.sql.sources.v2.TableCapability.{CONTINUOUS_READ, MICRO_BATCH_READ} + +/** + * This rule adds some basic table capability checks for streaming scans, without knowing the actual + * streaming execution mode. + */ +object V2StreamingScanSupportCheck extends (LogicalPlan => Unit) { + import DataSourceV2Implicits._ + + override def apply(plan: LogicalPlan): Unit = { + plan.foreach { + case r: StreamingRelationV2 if !r.table.supportsAny(MICRO_BATCH_READ, CONTINUOUS_READ) => + throw new AnalysisException( + s"Table ${r.table.name()} does not support either micro-batch or continuous scan.") + case _ => + } + + val streamingSources = plan.collect { + case r: StreamingRelationV2 => r.table + } + val v1StreamingRelations = plan.collect { + case r: StreamingRelation => r + } + + if (streamingSources.length + v1StreamingRelations.length > 1) { + val allSupportsMicroBatch = streamingSources.forall(_.supports(MICRO_BATCH_READ)) + // v1 streaming data source only supports micro-batch. + val allSupportsContinuous = streamingSources.forall(_.supports(CONTINUOUS_READ)) && + v1StreamingRelations.isEmpty + if (!allSupportsMicroBatch && !allSupportsContinuous) { + val microBatchSources = + streamingSources.filter(_.supports(MICRO_BATCH_READ)).map(_.name()) ++ + v1StreamingRelations.map(_.sourceName) + val continuousSources = streamingSources.filter(_.supports(CONTINUOUS_READ)).map(_.name()) + throw new AnalysisException( + "The streaming sources in a query do not have a common supported execution mode.\n" + + "Sources support micro-batch: " + microBatchSources.mkString(", ") + "\n" + + "Sources support continuous: " + continuousSources.mkString(", ")) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2WriteSupportCheck.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2WriteSupportCheck.scala new file mode 100644 index 0000000000000..cf77998c122f8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2WriteSupportCheck.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic} +import org.apache.spark.sql.sources.v2.TableCapability._ +import org.apache.spark.sql.types.BooleanType + +object V2WriteSupportCheck extends (LogicalPlan => Unit) { + import DataSourceV2Implicits._ + + def failAnalysis(msg: String): Unit = throw new AnalysisException(msg) + + override def apply(plan: LogicalPlan): Unit = plan foreach { + case AppendData(rel: DataSourceV2Relation, _, _) if !rel.table.supports(BATCH_WRITE) => + failAnalysis(s"Table does not support append in batch mode: ${rel.table}") + + case OverwritePartitionsDynamic(rel: DataSourceV2Relation, _, _) + if !rel.table.supports(BATCH_WRITE) || !rel.table.supports(OVERWRITE_DYNAMIC) => + failAnalysis(s"Table does not support dynamic overwrite in batch mode: ${rel.table}") + + case OverwriteByExpression(rel: DataSourceV2Relation, expr, _, _) => + expr match { + case Literal(true, BooleanType) => + if (!rel.table.supports(BATCH_WRITE) || + !rel.table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER)) { + failAnalysis( + s"Table does not support truncate in batch mode: ${rel.table}") + } + case _ => + if (!rel.table.supports(BATCH_WRITE) || !rel.table.supports(OVERWRITE_BY_FILTER)) { + failAnalysis(s"Table does not support overwrite expression ${expr.sql} " + + s"in batch mode: ${rel.table}") + } + } + + case _ => // OK + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 50c5e4f2ad7df..6c771ea988324 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -17,17 +17,26 @@ package org.apache.spark.sql.execution.datasources.v2 +import java.util.UUID + +import scala.collection.JavaConverters._ import scala.util.control.NonFatal import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.executor.CommitDeniedException import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalog.v2.{Identifier, TableCatalog} +import org.apache.spark.sql.catalog.v2.expressions.Transform import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} -import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.sources.{AlwaysTrue, Filter} +import org.apache.spark.sql.sources.v2.SupportsWrite +import org.apache.spark.sql.sources.v2.writer.{BatchWrite, DataWriterFactory, SupportsDynamicOverwrite, SupportsOverwrite, SupportsTruncate, WriteBuilder, WriterCommitMessage} +import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.{LongAccumulator, Utils} /** @@ -42,17 +51,170 @@ case class WriteToDataSourceV2(batchWrite: BatchWrite, query: LogicalPlan) } /** - * The physical plan for writing data into data source v2. + * Physical plan node for v2 create table as select. 
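A hedged restatement of the write capability rules enforced by V2WriteSupportCheck above (the two predicate names are invented for this sketch): truncating overwrites and filtered overwrites require different capability combinations on top of BATCH_WRITE.

import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
import org.apache.spark.sql.sources.v2.Table
import org.apache.spark.sql.sources.v2.TableCapability._

// OverwriteByExpression with a literal `true` condition is a truncate.
def canTruncate(table: Table): Boolean =
  table.supports(BATCH_WRITE) && table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER)

// Any other overwrite expression must be pushed down as a filter.
def canOverwriteByFilter(table: Table): Boolean =
  table.supports(BATCH_WRITE) && table.supports(OVERWRITE_BY_FILTER)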
+ * + * A new table will be created using the schema of the query, and rows from the query are appended. + * If either table creation or the append fails, the table will be deleted. This implementation does + * not provide an atomic CTAS. + */ +case class CreateTableAsSelectExec( + catalog: TableCatalog, + ident: Identifier, + partitioning: Seq[Transform], + query: SparkPlan, + properties: Map[String, String], + writeOptions: CaseInsensitiveStringMap, + ifNotExists: Boolean) extends V2TableWriteExec { + + import org.apache.spark.sql.catalog.v2.CatalogV2Implicits.IdentifierHelper + + override protected def doExecute(): RDD[InternalRow] = { + if (catalog.tableExists(ident)) { + if (ifNotExists) { + return sparkContext.parallelize(Seq.empty, 1) + } + + throw new TableAlreadyExistsException(ident) + } + + Utils.tryWithSafeFinallyAndFailureCallbacks({ + catalog.createTable(ident, query.schema, partitioning.toArray, properties.asJava) match { + case table: SupportsWrite => + val batchWrite = table.newWriteBuilder(writeOptions) + .withInputDataSchema(query.schema) + .withQueryId(UUID.randomUUID().toString) + .buildForBatch() + + doWrite(batchWrite) + + case _ => + // table does not support writes + throw new SparkException(s"Table implementation does not support writes: ${ident.quoted}") + } + + })(catchBlock = { + catalog.dropTable(ident) + }) + } +} + +/** + * Physical plan node for append into a v2 table. + * + * Rows in the output data set are appended. */ -case class WriteToDataSourceV2Exec(batchWrite: BatchWrite, query: SparkPlan) - extends UnaryExecNode { +case class AppendDataExec( + table: SupportsWrite, + writeOptions: CaseInsensitiveStringMap, + query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { + + override protected def doExecute(): RDD[InternalRow] = { + val batchWrite = newWriteBuilder().buildForBatch() + doWrite(batchWrite) + } +} + +/** + * Physical plan node for overwrite into a v2 table. + * + * Overwrites data in a table matched by a set of filters. Rows matching all of the filters will be + * deleted and rows in the output data set are appended. + * + * This plan is used to implement SaveMode.Overwrite. The behavior of SaveMode.Overwrite is to + * truncate the table -- delete all rows -- and append the output data set. This uses the filter + * AlwaysTrue to delete all rows. + */ +case class OverwriteByExpressionExec( + table: SupportsWrite, + deleteWhere: Array[Filter], + writeOptions: CaseInsensitiveStringMap, + query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { + + private def isTruncate(filters: Array[Filter]): Boolean = { + filters.length == 1 && filters(0).isInstanceOf[AlwaysTrue] + } + + override protected def doExecute(): RDD[InternalRow] = { + val batchWrite = newWriteBuilder() match { + case builder: SupportsTruncate if isTruncate(deleteWhere) => + builder.truncate().buildForBatch() + + case builder: SupportsOverwrite => + builder.overwrite(deleteWhere).buildForBatch() + + case _ => + throw new SparkException(s"Table does not support overwrite by expression: $table") + } + + doWrite(batchWrite) + } +} + +/** + * Physical plan node for dynamic partition overwrite into a v2 table. + * + * Dynamic partition overwrite is the behavior of Hive INSERT OVERWRITE ... PARTITION queries, and + * Spark INSERT OVERWRITE queries when spark.sql.sources.partitionOverwriteMode=dynamic. Each + * partition in the output data set replaces the corresponding existing partition in the table or + * creates a new partition. 
Existing partitions for which there is no data in the output data set + * are not modified. + */ +case class OverwritePartitionsDynamicExec( + table: SupportsWrite, + writeOptions: CaseInsensitiveStringMap, + query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { + + override protected def doExecute(): RDD[InternalRow] = { + val batchWrite = newWriteBuilder() match { + case builder: SupportsDynamicOverwrite => + builder.overwriteDynamicPartitions().buildForBatch() + + case _ => + throw new SparkException(s"Table does not support dynamic partition overwrite: $table") + } + + doWrite(batchWrite) + } +} + +case class WriteToDataSourceV2Exec( + batchWrite: BatchWrite, + query: SparkPlan) extends V2TableWriteExec { + + def writeOptions: CaseInsensitiveStringMap = CaseInsensitiveStringMap.empty() + + override protected def doExecute(): RDD[InternalRow] = { + doWrite(batchWrite) + } +} + +/** + * Helper for physical plans that build batch writes. + */ +trait BatchWriteHelper { + def table: SupportsWrite + def query: SparkPlan + def writeOptions: CaseInsensitiveStringMap + + def newWriteBuilder(): WriteBuilder = { + table.newWriteBuilder(writeOptions) + .withInputDataSchema(query.schema) + .withQueryId(UUID.randomUUID().toString) + } +} + +/** + * The base physical plan for writing data into data source v2. + */ +trait V2TableWriteExec extends UnaryExecNode { + def query: SparkPlan var commitProgress: Option[StreamWriterCommitProgress] = None override def child: SparkPlan = query override def output: Seq[Attribute] = Nil - override protected def doExecute(): RDD[InternalRow] = { + protected def doWrite(batchWrite: BatchWrite): RDD[InternalRow] = { val writerFactory = batchWrite.createBatchWriterFactory() val useCommitCoordinator = batchWrite.useCommitCoordinator val rdd = query.execute() @@ -169,8 +331,8 @@ object DataWritingSparkTask extends Logging { } private[v2] case class DataWritingSparkTaskResult( - numRows: Long, - writerCommitMessage: WriterCommitMessage) + numRows: Long, + writerCommitMessage: WriterCommitMessage) /** * Sink progress information collected after commit. 
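A stand-alone sketch of what the BatchWriteHelper trait above assembles for the write nodes (the function name and parameters are assumptions for illustration): the builder receives the write options, the incoming schema, and a fresh query id before the batch write is built.

import java.util.UUID
import org.apache.spark.sql.sources.v2.SupportsWrite
import org.apache.spark.sql.sources.v2.writer.BatchWrite
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

// Equivalent of newWriteBuilder().buildForBatch() for a table that supports writes.
def buildBatchWrite(
    table: SupportsWrite,
    schema: StructType,
    options: CaseInsensitiveStringMap): BatchWrite = {
  table.newWriteBuilder(options)
    .withInputDataSchema(schema)
    .withQueryId(UUID.randomUUID().toString)
    .buildForBatch()
}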
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala index db1f2f7934221..900c94e937ffc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql.execution.datasources.v2.orc import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.v2._ -import org.apache.spark.sql.sources.v2.{DataSourceOptions, Table} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.sources.v2.Table +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap class OrcDataSourceV2 extends FileDataSourceV2 { @@ -28,19 +29,36 @@ class OrcDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "orc" - private def getTableName(options: DataSourceOptions): String = { - shortName() + ":" + options.paths().mkString(";") + private def getTableName(paths: Seq[String]): String = { + shortName() + ":" + paths.mkString(";") } - override def getTable(options: DataSourceOptions): Table = { - val tableName = getTableName(options) - val fileIndex = getFileIndex(options, None) - OrcTable(tableName, sparkSession, fileIndex, None) + override def getTable(options: CaseInsensitiveStringMap): Table = { + val paths = getPaths(options) + val tableName = getTableName(paths) + OrcTable(tableName, sparkSession, options, paths, None) } - override def getTable(options: DataSourceOptions, schema: StructType): Table = { - val tableName = getTableName(options) - val fileIndex = getFileIndex(options, Some(schema)) - OrcTable(tableName, sparkSession, fileIndex, Some(schema)) + override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = { + val paths = getPaths(options) + val tableName = getTableName(paths) + OrcTable(tableName, sparkSession, options, paths, Some(schema)) + } +} + +object OrcDataSourceV2 { + def supportsDataType(dataType: DataType): Boolean = dataType match { + case _: AtomicType => true + + case st: StructType => st.forall { f => supportsDataType(f.dataType) } + + case ArrayType(elementType, _) => supportsDataType(elementType) + + case MapType(keyType, valueType, _) => + supportsDataType(keyType) && supportsDataType(valueType) + + case udt: UserDefinedType[_] => supportsDataType(udt.sqlType) + + case _ => false } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala index a792ad318b398..3c5dc1f50d7e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.v2.FileScan import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.SerializableConfiguration case class OrcScan( @@ -31,7 
+31,7 @@ case class OrcScan( hadoopConf: Configuration, fileIndex: PartitioningAwareFileIndex, dataSchema: StructType, - readSchema: StructType) extends FileScan(sparkSession, fileIndex) { + readSchema: StructType) extends FileScan(sparkSession, fileIndex, readSchema) { override def isSplitable(path: Path): Boolean = true override def createReaderFactory(): PartitionReaderFactory = { @@ -40,4 +40,10 @@ case class OrcScan( OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf, dataSchema, fileIndex.partitionSchema, readSchema) } + + override def supportsDataType(dataType: DataType): Boolean = { + OrcDataSourceV2.supportsDataType(dataType) + } + + override def formatName: String = "ORC" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala index eb27bbd3abeaa..a2c55e8c43021 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala @@ -26,18 +26,21 @@ import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.orc.OrcFilters import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader.Scan import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap case class OrcScanBuilder( sparkSession: SparkSession, fileIndex: PartitioningAwareFileIndex, schema: StructType, dataSchema: StructType, - options: DataSourceOptions) extends FileScanBuilder(schema) { - lazy val hadoopConf = - sparkSession.sessionState.newHadoopConfWithOptions(options.asMap().asScala.toMap) + options: CaseInsensitiveStringMap) extends FileScanBuilder(schema) { + lazy val hadoopConf = { + val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap + // Hadoop Configurations are case sensitive. 
+ sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + } override def build(): Scan = { OrcScan(sparkSession, hadoopConf, fileIndex, dataSchema, readSchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala index b467e505f1bac..aac38fb3fa1ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala @@ -19,25 +19,26 @@ package org.apache.spark.sql.execution.datasources.v2.orc import org.apache.hadoop.fs.FileStatus import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.orc.OrcUtils import org.apache.spark.sql.execution.datasources.v2.FileTable -import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.writer.WriteBuilder import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap case class OrcTable( name: String, sparkSession: SparkSession, - fileIndex: PartitioningAwareFileIndex, + options: CaseInsensitiveStringMap, + paths: Seq[String], userSpecifiedSchema: Option[StructType]) - extends FileTable(sparkSession, fileIndex, userSpecifiedSchema) { - override def newScanBuilder(options: DataSourceOptions): OrcScanBuilder = + extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { + + override def newScanBuilder(options: CaseInsensitiveStringMap): OrcScanBuilder = new OrcScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) override def inferSchema(files: Seq[FileStatus]): Option[StructType] = OrcUtils.readSchema(sparkSession, files) - override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = - new OrcWriteBuilder(options) + override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = + new OrcWriteBuilder(options, paths) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala index 80429d91d5e4d..b1f8b8916a390 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala @@ -25,10 +25,16 @@ import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFac import org.apache.spark.sql.execution.datasources.orc.{OrcFileFormat, OrcOptions, OrcOutputWriter, OrcUtils} import org.apache.spark.sql.execution.datasources.v2.FileWriteBuilder import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class OrcWriteBuilder(options: CaseInsensitiveStringMap, paths: Seq[String]) + extends FileWriteBuilder( + options, + paths, + "orc", + supportsDataType = OrcDataSourceV2.supportsDataType) { -class OrcWriteBuilder(options: DataSourceOptions) extends FileWriteBuilder(options) { override def prepareWrite( sqlConf: SQLConf, job: Job, @@ -63,4 +69,10 @@ class OrcWriteBuilder(options: DataSourceOptions) extends FileWriteBuilder(optio } } } + + override def supportsDataType(dataType: DataType): Boolean = { + 
OrcDataSourceV2.supportsDataType(dataType) + } + + override def formatName: String = "ORC" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 2ab7240556aaa..0c78cca086ed3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -28,8 +28,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} -import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.{DataType, StructField, StructType} +import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.util.Utils /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index a5203daea9cd0..d1105f0382f6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -25,8 +25,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode} import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.ArrowUtils /** * Grouped a iterator into batches. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 04623b1ab3c2f..3710218b2af5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -29,8 +29,9 @@ import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter} import org.apache.spark._ import org.apache.spark.api.python._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.arrow.{ArrowUtils, ArrowWriter} +import org.apache.spark.sql.execution.arrow.ArrowWriter import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index e9cff1a5a2007..c598b7c671a42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} -import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.ArrowUtils /** * Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandas]] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala index 1ce1215bfdd62..01ce07b133ffd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala @@ -29,9 +29,9 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan} -import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.execution.window._ import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.util.Utils /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index b3d12f67b5d63..b679f163fc561 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -20,13 +20,15 @@ package org.apache.spark.sql.execution.streaming import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileSystem, Path} 
+import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, FileFormat, FileFormatWriter} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.SerializableConfiguration object FileStreamSink extends Logging { @@ -37,23 +39,54 @@ object FileStreamSink extends Logging { * Returns true if there is a single path that has a metadata log indicating which files should * be read. */ - def hasMetadata(path: Seq[String], hadoopConf: Configuration): Boolean = { + def hasMetadata(path: Seq[String], hadoopConf: Configuration, sqlConf: SQLConf): Boolean = { path match { case Seq(singlePath) => + val hdfsPath = new Path(singlePath) + val fs = hdfsPath.getFileSystem(hadoopConf) + if (fs.isDirectory(hdfsPath)) { + val metadataPath = new Path(hdfsPath, metadataDir) + checkEscapedMetadataPath(fs, metadataPath, sqlConf) + fs.exists(metadataPath) + } else { + false + } + case _ => false + } + } + + def checkEscapedMetadataPath(fs: FileSystem, metadataPath: Path, sqlConf: SQLConf): Unit = { + if (sqlConf.getConf(SQLConf.STREAMING_CHECKPOINT_ESCAPED_PATH_CHECK_ENABLED) + && StreamExecution.containsSpecialCharsInPath(metadataPath)) { + val legacyMetadataPath = new Path(metadataPath.toUri.toString) + val legacyMetadataPathExists = try { - val hdfsPath = new Path(singlePath) - val fs = hdfsPath.getFileSystem(hadoopConf) - if (fs.isDirectory(hdfsPath)) { - fs.exists(new Path(hdfsPath, metadataDir)) - } else { - false - } + fs.exists(legacyMetadataPath) } catch { case NonFatal(e) => - logWarning(s"Error while looking for metadata directory.") + // We may not have access to this directory. Don't fail the query if that happens. + logWarning(e.getMessage, e) false } - case _ => false + if (legacyMetadataPathExists) { + throw new SparkException( + s"""Error: we detected a possible problem with the location of your "_spark_metadata" + |directory and you likely need to move it before restarting this query. + | + |Earlier version of Spark incorrectly escaped paths when writing out the + |"_spark_metadata" directory for structured streaming. While this was corrected in + |Spark 3.0, it appears that your query was started using an earlier version that + |incorrectly handled the "_spark_metadata" path. + | + |Correct "_spark_metadata" Directory: $metadataPath + |Incorrect "_spark_metadata" Directory: $legacyMetadataPath + | + |Please move the data from the incorrect directory to the correct one, delete the + |incorrect directory, and then restart this query. 
If you believe you are receiving + |this message in error, you can disable it with the SQL conf + |${SQLConf.STREAMING_CHECKPOINT_ESCAPED_PATH_CHECK_ENABLED.key}.""" + .stripMargin) + } } } @@ -92,11 +125,16 @@ class FileStreamSink( partitionColumnNames: Seq[String], options: Map[String, String]) extends Sink with Logging { + private val hadoopConf = sparkSession.sessionState.newHadoopConf() private val basePath = new Path(path) - private val logPath = new Path(basePath, FileStreamSink.metadataDir) + private val logPath = { + val metadataDir = new Path(basePath, FileStreamSink.metadataDir) + val fs = metadataDir.getFileSystem(hadoopConf) + FileStreamSink.checkEscapedMetadataPath(fs, metadataDir, sparkSession.sessionState.conf) + metadataDir + } private val fileLog = - new FileStreamSinkLog(FileStreamSinkLog.VERSION, sparkSession, logPath.toUri.toString) - private val hadoopConf = sparkSession.sessionState.newHadoopConf() + new FileStreamSinkLog(FileStreamSinkLog.VERSION, sparkSession, logPath.toString) private def basicWriteJobStatsTracker: BasicWriteJobStatsTracker = { val serializableHadoopConf = new SerializableConfiguration(hadoopConf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 103fa7ce9066d..43b70ae0a51b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -208,7 +208,7 @@ class FileStreamSource( var allFiles: Seq[FileStatus] = null sourceHasMetadata match { case None => - if (FileStreamSink.hasMetadata(Seq(path), hadoopConf)) { + if (FileStreamSink.hasMetadata(Seq(path), hadoopConf, sparkSession.sessionState.conf)) { sourceHasMetadata = Some(true) allFiles = allFilesUsingMetadataLogFileIndex() } else { @@ -220,7 +220,7 @@ class FileStreamSource( // double check whether source has metadata, preventing the extreme corner case that // metadata log and data files are only generated after the previous // `FileStreamSink.hasMetadata` check - if (FileStreamSink.hasMetadata(Seq(path), hadoopConf)) { + if (FileStreamSink.hasMetadata(Seq(path), hadoopConf, sparkSession.sessionState.conf)) { sourceHasMetadata = Some(true) allFiles = allFilesUsingMetadataLogFileIndex() } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala index 3ff5b86ac45d6..a27898cb0c9fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala @@ -17,12 +17,10 @@ package org.apache.spark.sql.execution.streaming -import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} - /** * A simple offset for sources that produce a single linear stream of data. */ -case class LongOffset(offset: Long) extends OffsetV2 { +case class LongOffset(offset: Long) extends Offset { override val json = offset.toString @@ -37,14 +35,4 @@ object LongOffset { * @return new LongOffset */ def apply(offset: SerializedOffset) : LongOffset = new LongOffset(offset.json.toLong) - - /** - * Convert generic Offset to LongOffset if possible. 
- * @return converted LongOffset - */ - def convert(offset: Offset): Option[LongOffset] = offset match { - case lo: LongOffset => Some(lo) - case so: SerializedOffset => Some(LongOffset(so)) - case _ => None - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala index 5cacdd070b735..80eed7b277216 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala @@ -39,10 +39,16 @@ class MetadataLogFileIndex( userSpecifiedSchema: Option[StructType]) extends PartitioningAwareFileIndex(sparkSession, Map.empty, userSpecifiedSchema) { - private val metadataDirectory = new Path(path, FileStreamSink.metadataDir) + private val metadataDirectory = { + val metadataDir = new Path(path, FileStreamSink.metadataDir) + val fs = metadataDir.getFileSystem(sparkSession.sessionState.newHadoopConf()) + FileStreamSink.checkEscapedMetadataPath(fs, metadataDir, sparkSession.sessionState.conf) + metadataDir + } + logInfo(s"Reading streaming file log from $metadataDirectory") private val metadataLog = - new FileStreamSinkLog(FileStreamSinkLog.VERSION, sparkSession, metadataDirectory.toUri.toString) + new FileStreamSinkLog(FileStreamSinkLog.VERSION, sparkSession, metadataDirectory.toString) private val allFilesFromLog = metadataLog.allFiles().map(_.toFileStatus).filterNot(_.isDirectory) private var cachedPartitionSpec: PartitionSpec = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 2c339759f95ba..7a3cdbc926446 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution.streaming -import scala.collection.JavaConverters._ import scala.collection.mutable.{Map => MutableMap} import org.apache.spark.sql.{Dataset, SparkSession} @@ -26,11 +25,11 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CurrentBatch import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, StreamWriterCommitProgress, WriteToDataSourceV2, WriteToDataSourceV2Exec} -import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWrite, RateControlMicroBatchStream} +import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, StreamWriterCommitProgress, WriteToDataSourceV2Exec} +import org.apache.spark.sql.execution.streaming.sources.{RateControlMicroBatchStream, WriteToMicroBatchDataSource} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, Offset => OffsetV2} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, Offset => OffsetV2, SparkDataStream} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.Clock @@ -39,7 +38,7 @@ class MicroBatchExecution( name: 
String, checkpointRoot: String, analyzedPlan: LogicalPlan, - sink: BaseStreamingSink, + sink: Table, trigger: Trigger, triggerClock: Clock, outputMode: OutputMode, @@ -49,7 +48,7 @@ class MicroBatchExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, trigger, triggerClock, outputMode, deleteCheckpointOnStop) { - @volatile protected var sources: Seq[BaseStreamingSource] = Seq.empty + @volatile protected var sources: Seq[SparkDataStream] = Seq.empty private val triggerExecutor = trigger match { case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) @@ -78,6 +77,7 @@ class MicroBatchExecution( val disabledSources = sparkSession.sqlContext.conf.disabledV2StreamingMicroBatchReaders.split(",") + import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ val _logicalPlan = analyzedPlan.transform { case streamingRelation@StreamingRelation(dataSourceV1, sourceName, output) => toExecutionRelationMap.getOrElseUpdate(streamingRelation, { @@ -88,32 +88,33 @@ class MicroBatchExecution( logInfo(s"Using Source [$source] from DataSourceV1 named '$sourceName' [$dataSourceV1]") StreamingExecutionRelation(source, output)(sparkSession) }) - case s @ StreamingRelationV2(ds, dsName, table: SupportsMicroBatchRead, options, output, _) - if !disabledSources.contains(ds.getClass.getCanonicalName) => - v2ToRelationMap.getOrElseUpdate(s, { - // Materialize source to avoid creating it in every batch - val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" - nextSourceId += 1 - logInfo(s"Reading table [$table] from DataSourceV2 named '$dsName' [$ds]") - val dsOptions = new DataSourceOptions(options.asJava) - // TODO: operator pushdown. - val scan = table.newScanBuilder(dsOptions).build() - val stream = scan.toMicroBatchStream(metadataPath) - StreamingDataSourceV2Relation(output, scan, stream) - }) - case s @ StreamingRelationV2(ds, dsName, _, _, output, v1Relation) => - v2ToExecutionRelationMap.getOrElseUpdate(s, { - // Materialize source to avoid creating it in every batch - val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" - if (v1Relation.isEmpty) { - throw new UnsupportedOperationException( - s"Data source $dsName does not support microbatch processing.") - } - val source = v1Relation.get.dataSource.createSource(metadataPath) - nextSourceId += 1 - logInfo(s"Using Source [$source] from DataSourceV2 named '$dsName' [$ds]") - StreamingExecutionRelation(source, output)(sparkSession) - }) + + case s @ StreamingRelationV2(src, srcName, table: SupportsRead, options, output, v1) => + val v2Disabled = disabledSources.contains(src.getClass.getCanonicalName) + if (!v2Disabled && table.supports(TableCapability.MICRO_BATCH_READ)) { + v2ToRelationMap.getOrElseUpdate(s, { + // Materialize source to avoid creating it in every batch + val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" + nextSourceId += 1 + logInfo(s"Reading table [$table] from DataSourceV2 named '$srcName' [$src]") + // TODO: operator pushdown. 
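For context on the branch just above (an illustrative sketch, not part of the patch): whether a StreamingRelationV2 takes the v2 micro-batch path is now decided purely by the disabled-sources list plus the table's advertised capabilities. A minimal sketch of that decision using only the Table.capabilities() API that appears elsewhere in this diff; the helper name is hypothetical.

import org.apache.spark.sql.sources.v2.{Table, TableCapability, TableProvider}

// Hypothetical helper mirroring the check above: use the v2 micro-batch scan only when
// the provider is not disabled and the table reports MICRO_BATCH_READ; otherwise
// MicroBatchExecution falls back to the v1 Source (if the relation carries one).
def useV2MicroBatchStream(
    provider: TableProvider,
    table: Table,
    disabledSources: Seq[String]): Boolean = {
  val v2Disabled = disabledSources.contains(provider.getClass.getCanonicalName)
  !v2Disabled && table.capabilities().contains(TableCapability.MICRO_BATCH_READ)
}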
+ val scan = table.newScanBuilder(options).build() + val stream = scan.toMicroBatchStream(metadataPath) + StreamingDataSourceV2Relation(output, scan, stream) + }) + } else if (v1.isEmpty) { + throw new UnsupportedOperationException( + s"Data source $srcName does not support microbatch processing.") + } else { + v2ToExecutionRelationMap.getOrElseUpdate(s, { + // Materialize source to avoid creating it in every batch + val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" + val source = v1.get.dataSource.createSource(metadataPath) + nextSourceId += 1 + logInfo(s"Using Source [$source] from DataSourceV2 named '$srcName' [$src]") + StreamingExecutionRelation(source, output)(sparkSession) + }) + } } sources = _logicalPlan.collect { // v1 source @@ -122,7 +123,15 @@ class MicroBatchExecution( case r: StreamingDataSourceV2Relation => r.stream } uniqueSources = sources.distinct - _logicalPlan + + // TODO (SPARK-27484): we should add the writing node before the plan is analyzed. + sink match { + case s: SupportsWrite => + val streamingWrite = createStreamingWrite(s, extraOptions, _logicalPlan) + WriteToMicroBatchDataSource(streamingWrite, _logicalPlan) + + case _ => _logicalPlan + } } /** @@ -287,7 +296,7 @@ class MicroBatchExecution( * batch will be executed before getOffset is called again. */ availableOffsets.foreach { case (source: Source, end: Offset) => - val start = committedOffsets.get(source) + val start = committedOffsets.get(source).map(_.asInstanceOf[Offset]) source.getBatch(start, end) case nonV1Tuple => // The V2 API does not have the same edge case requiring getBatch to be called @@ -345,7 +354,7 @@ class MicroBatchExecution( if (isCurrentBatchConstructed) return true // Generate a map from each unique source to the next available offset. - val latestOffsets: Map[BaseStreamingSource, Option[Offset]] = uniqueSources.map { + val latestOffsets: Map[SparkDataStream, Option[OffsetV2]] = uniqueSources.map { case s: Source => updateStatusMessage(s"Getting offsets from $s") reportTimeTaken("getOffset") { @@ -402,7 +411,7 @@ class MicroBatchExecution( val prevBatchOff = offsetLog.get(currentBatchId - 1) if (prevBatchOff.isDefined) { prevBatchOff.get.toStreamProgress(sources).foreach { - case (src: Source, off) => src.commit(off) + case (src: Source, off: Offset) => src.commit(off) case (stream: MicroBatchStream, off) => stream.commit(stream.deserializeOffset(off.json)) case (src, _) => @@ -439,9 +448,9 @@ class MicroBatchExecution( // Request unprocessed data from all sources. 
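A note on the asInstanceOf[Offset] casts above (a sketch of the reasoning, not part of the patch): progress maps such as committedOffsets now hold the v2 offset type, while the v1 Source API still consumes the legacy Offset class, which after this change has to be a subtype of the v2 Offset for these casts to succeed. Any offset recorded for a v1 source was produced by that source or read back from the log, so the downcast is expected to hold; a defensive variant would look like this hypothetical helper.

import org.apache.spark.sql.execution.streaming.{Offset => OffsetV1}
import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2}

// Illustrative only: recover the legacy offset type expected by a v1 Source from the
// v2-typed value stored in the progress map, failing loudly on anything unexpected.
def toV1Offset(offset: OffsetV2): OffsetV1 = offset match {
  case v1: OffsetV1 => v1
  case other => throw new IllegalStateException(
    s"Unexpected offset type for a v1 source: ${other.getClass.getName}")
}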
newData = reportTimeTaken("getBatch") { availableOffsets.flatMap { - case (source: Source, available) + case (source: Source, available: Offset) if committedOffsets.get(source).map(_ != available).getOrElse(true) => - val current = committedOffsets.get(source) + val current = committedOffsets.get(source).map(_.asInstanceOf[Offset]) val batch = source.getBatch(current, available) assert(batch.isStreaming, s"DataFrame returned by getBatch from $source did not have isStreaming=true\n" + @@ -513,13 +522,8 @@ class MicroBatchExecution( val triggerLogicalPlan = sink match { case _: Sink => newAttributePlan - case s: StreamingWriteSupportProvider => - val writer = s.createStreamingWriteSupport( - s"$runId", - newAttributePlan.schema, - outputMode, - new DataSourceOptions(extraOptions.asJava)) - WriteToDataSourceV2(new MicroBatchWrite(currentBatchId, writer), newAttributePlan) + case _: SupportsWrite => + newAttributePlan.asInstanceOf[WriteToMicroBatchDataSource].createPlan(currentBatchId) case _ => throw new IllegalArgumentException(s"unknown sink type for $sink") } @@ -549,7 +553,7 @@ class MicroBatchExecution( SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) { sink match { case s: Sink => s.addBatch(currentBatchId, nextBatch) - case _: StreamingWriteSupportProvider => + case _: SupportsWrite => // This doesn't accumulate any data - it just forces execution of the microbatch writer. nextBatch.collect() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java deleted file mode 100644 index 43ad4b3384ec3..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming; - -/** - * This is an internal, deprecated interface. New source implementations should use the - * org.apache.spark.sql.sources.v2.reader.streaming.Offset class, which is the one that will be - * supported in the long term. - * - * This class will be removed in a future release. - */ -public abstract class Offset { - /** - * A JSON-serialized representation of an Offset that is - * used for saving offsets to the offset log. - * Note: We assume that equivalent/equal offsets serialize to - * identical JSON strings. - * - * @return JSON string encoding - */ - public abstract String json(); - - /** - * Equality based on JSON string representation. We leverage the - * JSON representation for normalization between the Offset's - * in memory and on disk representations. 
- */ - @Override - public boolean equals(Object obj) { - if (obj instanceof Offset) { - return this.json().equals(((Offset) obj).json()); - } else { - return false; - } - } - - @Override - public int hashCode() { - return this.json().hashCode(); - } - - @Override - public String toString() { - return this.json(); - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 73cf355dbe758..b6fa2e9dc3612 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -24,13 +24,15 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.RuntimeConfig import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StreamingAggregationStateManager} import org.apache.spark.sql.internal.SQLConf.{FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, _} +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2, SparkDataStream} + /** * An ordered collection of offsets, used to track the progress of processing data from one or more * [[Source]]s that are present in a streaming query. This is similar to simplified, single-instance * vector clock that must progress linearly forward. */ -case class OffsetSeq(offsets: Seq[Option[Offset]], metadata: Option[OffsetSeqMetadata] = None) { +case class OffsetSeq(offsets: Seq[Option[OffsetV2]], metadata: Option[OffsetSeqMetadata] = None) { /** * Unpacks an offset into [[StreamProgress]] by associating each offset with the ordered list of @@ -39,7 +41,7 @@ case class OffsetSeq(offsets: Seq[Option[Offset]], metadata: Option[OffsetSeqMet * This method is typically used to associate a serialized offset with actual sources (which * cannot be serialized). */ - def toStreamProgress(sources: Seq[BaseStreamingSource]): StreamProgress = { + def toStreamProgress(sources: Seq[SparkDataStream]): StreamProgress = { assert(sources.size == offsets.size, s"There are [${offsets.size}] sources in the " + s"checkpoint offsets and now there are [${sources.size}] sources requested by the query. " + s"Cannot continue.") @@ -56,13 +58,13 @@ object OffsetSeq { * Returns a [[OffsetSeq]] with a variable sequence of offsets. * `nulls` in the sequence are converted to `None`s. */ - def fill(offsets: Offset*): OffsetSeq = OffsetSeq.fill(None, offsets: _*) + def fill(offsets: OffsetV2*): OffsetSeq = OffsetSeq.fill(None, offsets: _*) /** * Returns a [[OffsetSeq]] with metadata and a variable sequence of offsets. * `nulls` in the sequence are converted to `None`s. 
*/ - def fill(metadata: Option[String], offsets: Offset*): OffsetSeq = { + def fill(metadata: Option[String], offsets: OffsetV2*): OffsetSeq = { OffsetSeq(offsets.map(Option(_)), metadata.map(OffsetSeqMetadata.apply)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala index 2c8d7c7b0f3c5..8a05dade092c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala @@ -24,6 +24,7 @@ import java.nio.charset.StandardCharsets._ import scala.io.{Source => IOSource} import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} /** * This class is used to log offsets to persistent files in HDFS. @@ -47,7 +48,7 @@ class OffsetSeqLog(sparkSession: SparkSession, path: String) override protected def deserialize(in: InputStream): OffsetSeq = { // called inside a try-finally where the underlying stream is closed in the caller - def parseOffset(value: String): Offset = value match { + def parseOffset(value: String): OffsetV2 = value match { case OffsetSeqLog.SERIALIZED_VOID_OFFSET => null case json => SerializedOffset(json) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 25283515b882f..932daef8965d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -29,7 +29,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalP import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.datasources.v2.{MicroBatchScanExec, StreamingDataSourceV2Relation, StreamWriterCommitProgress} -import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchStream +import org.apache.spark.sql.sources.v2.Table +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, SparkDataStream} import org.apache.spark.sql.streaming._ import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent import org.apache.spark.util.Clock @@ -44,7 +45,7 @@ import org.apache.spark.util.Clock trait ProgressReporter extends Logging { case class ExecutionStats( - inputRows: Map[BaseStreamingSource, Long], + inputRows: Map[SparkDataStream, Long], stateOperators: Seq[StateOperatorProgress], eventTimeStats: Map[String, String]) @@ -55,10 +56,10 @@ trait ProgressReporter extends Logging { protected def triggerClock: Clock protected def logicalPlan: LogicalPlan protected def lastExecution: QueryExecution - protected def newData: Map[BaseStreamingSource, LogicalPlan] + protected def newData: Map[SparkDataStream, LogicalPlan] protected def sinkCommitProgress: Option[StreamWriterCommitProgress] - protected def sources: Seq[BaseStreamingSource] - protected def sink: BaseStreamingSink + protected def sources: Seq[SparkDataStream] + protected def sink: Table protected def offsetSeqMetadata: OffsetSeqMetadata protected def currentBatchId: Long protected def sparkSession: SparkSession @@ -67,8 +68,8 @@ trait ProgressReporter extends Logging { // Local timestamps and counters. 
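The OffsetSeqLog change above only widens the declared type; what is persisted is still the JSON form of each offset, and it is read back as an opaque SerializedOffset until the owning source re-interprets it. A small round-trip sketch using LongOffset and SerializedOffset, both of which appear in this patch; the value 42 is illustrative.

import org.apache.spark.sql.execution.streaming.{LongOffset, SerializedOffset}

// Offsets are written to the offset log as their JSON representation...
val original = LongOffset(42L)
val persisted = original.json                 // "42"

// ...and read back as SerializedOffset (see parseOffset above). A source that knows it
// uses LongOffset can turn the raw JSON back into a typed offset.
val restored = SerializedOffset(persisted)
val reinterpreted = LongOffset(restored)      // LongOffset(42)
assert(reinterpreted == original)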
private var currentTriggerStartTimestamp = -1L private var currentTriggerEndTimestamp = -1L - private var currentTriggerStartOffsets: Map[BaseStreamingSource, String] = _ - private var currentTriggerEndOffsets: Map[BaseStreamingSource, String] = _ + private var currentTriggerStartOffsets: Map[SparkDataStream, String] = _ + private var currentTriggerEndOffsets: Map[SparkDataStream, String] = _ // TODO: Restore this from the checkpoint when possible. private var lastTriggerStartTimestamp = -1L @@ -240,9 +241,9 @@ trait ProgressReporter extends Logging { } /** Extract number of input sources for each streaming source in plan */ - private def extractSourceToNumInputRows(): Map[BaseStreamingSource, Long] = { + private def extractSourceToNumInputRows(): Map[SparkDataStream, Long] = { - def sumRows(tuples: Seq[(BaseStreamingSource, Long)]): Map[BaseStreamingSource, Long] = { + def sumRows(tuples: Seq[(SparkDataStream, Long)]): Map[SparkDataStream, Long] = { tuples.groupBy(_._1).mapValues(_.map(_._2).sum) // sum up rows for each source } @@ -262,7 +263,7 @@ trait ProgressReporter extends Logging { val sourceToInputRowsTuples = lastExecution.executedPlan.collect { case s: MicroBatchScanExec => val numRows = s.metrics.get("numOutputRows").map(_.value).getOrElse(0L) - val source = s.stream.asInstanceOf[BaseStreamingSource] + val source = s.stream source -> numRows } logDebug("Source -> # input rows\n\t" + sourceToInputRowsTuples.mkString("\n\t")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala index 34bc085d920c1..190325fb7ec25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala @@ -17,14 +17,21 @@ package org.apache.spark.sql.execution.streaming +import java.util + import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.sources.v2.{Table, TableCapability} +import org.apache.spark.sql.types.StructType /** * An interface for systems that can collect the results of a streaming query. In order to preserve * exactly once semantics a sink must be idempotent in the face of multiple attempts to add the same * batch. + * + * Note that, we extends `Table` here, to make the v1 streaming sink API be compatible with + * data source v2. */ -trait Sink extends BaseStreamingSink { +trait Sink extends Table { /** * Adds a batch of data to this sink. The data for a given `batchId` is deterministic and if @@ -38,4 +45,16 @@ trait Sink extends BaseStreamingSink { * after data is consumed by sink successfully. 
*/ def addBatch(batchId: Long, data: DataFrame): Unit + + override def name: String = { + throw new IllegalStateException("should not be called.") + } + + override def schema: StructType = { + throw new IllegalStateException("should not be called.") + } + + override def capabilities: util.Set[TableCapability] = { + throw new IllegalStateException("should not be called.") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala index dbbd59e06909c..7f66d0b055cc3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala @@ -18,14 +18,19 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} +import org.apache.spark.sql.sources.v2.reader.streaming.SparkDataStream import org.apache.spark.sql.types.StructType /** * A source of continually arriving data for a streaming query. A [[Source]] must have a * monotonically increasing notion of progress that can be represented as an [[Offset]]. Spark * will regularly query each [[Source]] to see if any more data is available. + * + * Note that, we extends `SparkDataStream` here, to make the v1 streaming source API be compatible + * with data source v2. */ -trait Source extends BaseStreamingSource { +trait Source extends SparkDataStream { /** Returns the schema of the data from this source */ def schema: StructType @@ -62,6 +67,15 @@ trait Source extends BaseStreamingSource { */ def commit(end: Offset) : Unit = {} - /** Stop this source and free any resources it has allocated. 
*/ - def stop(): Unit + override def initialOffset(): OffsetV2 = { + throw new IllegalStateException("should not be called.") + } + + override def deserializeOffset(json: String): OffsetV2 = { + throw new IllegalStateException("should not be called.") + } + + override def commit(end: OffsetV2): Unit = { + throw new IllegalStateException("should not be called.") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 90f7b477103ae..4c08b3aa78666 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -24,6 +24,7 @@ import java.util.concurrent.{CountDownLatch, ExecutionException, TimeUnit} import java.util.concurrent.atomic.AtomicReference import java.util.concurrent.locks.{Condition, ReentrantLock} +import scala.collection.JavaConverters._ import scala.collection.mutable.{Map => MutableMap} import scala.util.control.NonFatal @@ -34,11 +35,17 @@ import org.apache.spark.{SparkContext, SparkException} import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.StreamingExplainCommand import org.apache.spark.sql.execution.datasources.v2.StreamWriterCommitProgress import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.{SupportsWrite, Table} +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2, SparkDataStream} +import org.apache.spark.sql.sources.v2.writer.SupportsTruncate +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite import org.apache.spark.sql.streaming._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.{Clock, UninterruptibleThread, Utils} /** States for [[StreamExecution]]'s lifecycle. */ @@ -55,14 +62,15 @@ case object RECONFIGURING extends State * and the results are committed transactionally to the given [[Sink]]. * * @param deleteCheckpointOnStop whether to delete the checkpoint if the query is stopped without - * errors + * errors. Checkpoint deletion can be forced with the appropriate + * Spark configuration. */ abstract class StreamExecution( override val sparkSession: SparkSession, override val name: String, private val checkpointRoot: String, analyzedPlan: LogicalPlan, - val sink: BaseStreamingSink, + val sink: Table, val trigger: Trigger, val triggerClock: Clock, val outputMode: OutputMode, @@ -89,9 +97,47 @@ abstract class StreamExecution( val resolvedCheckpointRoot = { val checkpointPath = new Path(checkpointRoot) val fs = checkpointPath.getFileSystem(sparkSession.sessionState.newHadoopConf()) - fs.mkdirs(checkpointPath) - checkpointPath.makeQualified(fs.getUri, fs.getWorkingDirectory).toUri.toString + if (sparkSession.conf.get(SQLConf.STREAMING_CHECKPOINT_ESCAPED_PATH_CHECK_ENABLED) + && StreamExecution.containsSpecialCharsInPath(checkpointPath)) { + // In Spark 2.4 and earlier, the checkpoint path is escaped 3 times (3 `Path.toUri.toString` + // calls). If this legacy checkpoint path exists, we will throw an error to tell the user how + // to migrate. 
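To make the escaping problem above concrete (an illustrative sketch, not part of the patch; the exact strings depend on Hadoop Path and java.net.URI behaviour): every new Path(p.toUri.toString) round trip re-escapes characters such as spaces, so a checkpoint written by an older Spark can sit under a percent-encoded variant of the intended path. The detection is the same containsSpecialCharsInPath check this patch adds to the StreamExecution object further down.

import org.apache.hadoop.fs.Path

val checkpointPath = new Path("/tmp/chk dir")        // user-supplied path with a space
val once = new Path(checkpointPath.toUri.toString)   // "/tmp/chk%20dir"
val twice = new Path(once.toUri.toString)            // "/tmp/chk%2520dir" (re-escaped)

// Same shape as StreamExecution.containsSpecialCharsInPath: a path is "special" when a
// URI round trip changes it, which is exactly when the legacy escaping could differ.
def containsSpecialChars(path: Path): Boolean =
  path.toUri.getPath != new Path(path.toUri.toString).toUri.getPath

assert(containsSpecialChars(checkpointPath))
assert(!containsSpecialChars(new Path("/tmp/plain")))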
+ val legacyCheckpointDir = + new Path(new Path(checkpointPath.toUri.toString).toUri.toString).toUri.toString + val legacyCheckpointDirExists = + try { + fs.exists(new Path(legacyCheckpointDir)) + } catch { + case NonFatal(e) => + // We may not have access to this directory. Don't fail the query if that happens. + logWarning(e.getMessage, e) + false + } + if (legacyCheckpointDirExists) { + throw new SparkException( + s"""Error: we detected a possible problem with the location of your checkpoint and you + |likely need to move it before restarting this query. + | + |Earlier version of Spark incorrectly escaped paths when writing out checkpoints for + |structured streaming. While this was corrected in Spark 3.0, it appears that your + |query was started using an earlier version that incorrectly handled the checkpoint + |path. + | + |Correct Checkpoint Directory: $checkpointPath + |Incorrect Checkpoint Directory: $legacyCheckpointDir + | + |Please move the data from the incorrect directory to the correct one, delete the + |incorrect directory, and then restart this query. If you believe you are receiving + |this message in error, you can disable it with the SQL conf + |${SQLConf.STREAMING_CHECKPOINT_ESCAPED_PATH_CHECK_ENABLED.key}.""" + .stripMargin) + } + } + val checkpointDir = checkpointPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + fs.mkdirs(checkpointDir) + checkpointDir.toString } + logInfo(s"Checkpoint root $checkpointRoot resolved to $resolvedCheckpointRoot.") def logicalPlan: LogicalPlan @@ -160,7 +206,7 @@ abstract class StreamExecution( /** * A list of unique sources in the query plan. This will be set when generating logical plan. */ - @volatile protected var uniqueSources: Seq[BaseStreamingSource] = Seq.empty + @volatile protected var uniqueSources: Seq[SparkDataStream] = Seq.empty /** Defines the internal state of execution */ protected val state = new AtomicReference[State](INITIALIZING) @@ -169,7 +215,7 @@ abstract class StreamExecution( var lastExecution: IncrementalExecution = _ /** Holds the most recent input data for each source. */ - protected var newData: Map[BaseStreamingSource, LogicalPlan] = _ + protected var newData: Map[SparkDataStream, LogicalPlan] = _ @volatile protected var streamDeathCause: StreamingQueryException = null @@ -225,7 +271,7 @@ abstract class StreamExecution( /** Returns the path of a file with `name` in the checkpoint directory. */ protected def checkpointFile(name: String): String = - new Path(new Path(resolvedCheckpointRoot), name).toUri.toString + new Path(new Path(resolvedCheckpointRoot), name).toString /** * Starts the execution. 
This returns only after the thread has started and [[QueryStartedEvent]] @@ -335,10 +381,13 @@ abstract class StreamExecution( postEvent( new QueryTerminatedEvent(id, runId, exception.map(_.cause).map(Utils.exceptionString))) - // Delete the temp checkpoint only when the query didn't fail - if (deleteCheckpointOnStop && exception.isEmpty) { + // Delete the temp checkpoint when either force delete enabled or the query didn't fail + if (deleteCheckpointOnStop && + (sparkSession.sessionState.conf + .getConf(SQLConf.FORCE_DELETE_TEMP_CHECKPOINT_LOCATION) || exception.isEmpty)) { val checkpointPath = new Path(resolvedCheckpointRoot) try { + logInfo(s"Deleting checkpoint $checkpointPath.") val fs = checkpointPath.getFileSystem(sparkSession.sessionState.newHadoopConf()) fs.delete(checkpointPath, true) } catch { @@ -389,7 +438,7 @@ abstract class StreamExecution( * Blocks the current thread until processing for data from the given `source` has reached at * least the given `Offset`. This method is intended for use primarily when writing tests. */ - private[sql] def awaitOffset(sourceIndex: Int, newOffset: Offset, timeoutMs: Long): Unit = { + private[sql] def awaitOffset(sourceIndex: Int, newOffset: OffsetV2, timeoutMs: Long): Unit = { assertAwaitThread() def notDone = { val localCommittedOffsets = committedOffsets @@ -532,6 +581,35 @@ abstract class StreamExecution( Option(name).map(_ + "
    ").getOrElse("") + s"id = $id
<br/>runId = $runId<br/>
    batch = $batchDescription" } + + protected def createStreamingWrite( + table: SupportsWrite, + options: Map[String, String], + inputPlan: LogicalPlan): StreamingWrite = { + val writeBuilder = table.newWriteBuilder(new CaseInsensitiveStringMap(options.asJava)) + .withQueryId(id.toString) + .withInputDataSchema(inputPlan.schema) + outputMode match { + case Append => + writeBuilder.buildForStreaming() + + case Complete => + // TODO: we should do this check earlier when we have capability API. + require(writeBuilder.isInstanceOf[SupportsTruncate], + table.name + " does not support Complete mode.") + writeBuilder.asInstanceOf[SupportsTruncate].truncate().buildForStreaming() + + case Update => + // Although no v2 sinks really support Update mode now, but during tests we do want them + // to pretend to support Update mode, and treat Update mode same as Append mode. + if (Utils.isTesting) { + writeBuilder.buildForStreaming() + } else { + throw new IllegalArgumentException( + "Data source v2 streaming sinks does not support Update mode.") + } + } + } } object StreamExecution { @@ -568,6 +646,11 @@ object StreamExecution { case _ => false } + + /** Whether the path contains special chars that will be escaped when converting to a `URI`. */ + def containsSpecialCharsInPath(path: Path): Boolean = { + path.toUri.getPath != new Path(path.toUri.toString).toUri.getPath + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala index 8531070b1bc49..8783eaa0e68b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala @@ -19,32 +19,35 @@ package org.apache.spark.sql.execution.streaming import scala.collection.{immutable, GenTraversableOnce} +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2, SparkDataStream} + + /** * A helper class that looks like a Map[Source, Offset]. 
*/ class StreamProgress( - val baseMap: immutable.Map[BaseStreamingSource, Offset] = - new immutable.HashMap[BaseStreamingSource, Offset]) - extends scala.collection.immutable.Map[BaseStreamingSource, Offset] { + val baseMap: immutable.Map[SparkDataStream, OffsetV2] = + new immutable.HashMap[SparkDataStream, OffsetV2]) + extends scala.collection.immutable.Map[SparkDataStream, OffsetV2] { - def toOffsetSeq(source: Seq[BaseStreamingSource], metadata: OffsetSeqMetadata): OffsetSeq = { + def toOffsetSeq(source: Seq[SparkDataStream], metadata: OffsetSeqMetadata): OffsetSeq = { OffsetSeq(source.map(get), Some(metadata)) } override def toString: String = baseMap.map { case (k, v) => s"$k: $v"}.mkString("{", ",", "}") - override def +[B1 >: Offset](kv: (BaseStreamingSource, B1)): Map[BaseStreamingSource, B1] = { + override def +[B1 >: OffsetV2](kv: (SparkDataStream, B1)): Map[SparkDataStream, B1] = { baseMap + kv } - override def get(key: BaseStreamingSource): Option[Offset] = baseMap.get(key) + override def get(key: SparkDataStream): Option[OffsetV2] = baseMap.get(key) - override def iterator: Iterator[(BaseStreamingSource, Offset)] = baseMap.iterator + override def iterator: Iterator[(SparkDataStream, OffsetV2)] = baseMap.iterator - override def -(key: BaseStreamingSource): Map[BaseStreamingSource, Offset] = baseMap - key + override def -(key: SparkDataStream): Map[SparkDataStream, OffsetV2] = baseMap - key - def ++(updates: GenTraversableOnce[(BaseStreamingSource, Offset)]): StreamProgress = { + def ++(updates: GenTraversableOnce[(SparkDataStream, OffsetV2)]): StreamProgress = { new StreamProgress(baseMap ++ updates) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index 83d38dcade7e6..142b6e7d18068 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -25,7 +25,9 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.sources.v2.{DataSourceV2, Table} +import org.apache.spark.sql.sources.v2.{Table, TableProvider} +import org.apache.spark.sql.sources.v2.reader.streaming.SparkDataStream +import org.apache.spark.sql.util.CaseInsensitiveStringMap object StreamingRelation { def apply(dataSource: DataSource): StreamingRelation = { @@ -62,7 +64,7 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output: * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. */ case class StreamingExecutionRelation( - source: BaseStreamingSource, + source: SparkDataStream, output: Seq[Attribute])(session: SparkSession) extends LeafNode with MultiInstanceRelation { @@ -86,16 +88,16 @@ case class StreamingExecutionRelation( // know at read time whether the query is continuous or not, so we need to be able to // swap a V1 relation back in. /** - * Used to link a [[DataSourceV2]] into a streaming + * Used to link a [[TableProvider]] into a streaming * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. 
This is only used for creating * a streaming [[org.apache.spark.sql.DataFrame]] from [[org.apache.spark.sql.DataFrameReader]], * and should be converted before passing to [[StreamExecution]]. */ case class StreamingRelationV2( - dataSource: DataSourceV2, + source: TableProvider, sourceName: String, table: Table, - extraOptions: Map[String, String], + extraOptions: CaseInsensitiveStringMap, output: Seq[Attribute], v1Relation: Option[StreamingRelation])(session: SparkSession) extends LeafNode with MultiInstanceRelation { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 9c5c16f4f5d13..9ae39c79c5156 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -17,30 +17,30 @@ package org.apache.spark.sql.execution.streaming +import java.util + +import scala.collection.JavaConverters._ + import org.apache.spark.sql._ -import org.apache.spark.sql.execution.streaming.sources.ConsoleWriteSupport +import org.apache.spark.sql.execution.streaming.sources.ConsoleWrite import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamingWriteSupportProvider} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport -import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.writer.{SupportsTruncate, WriteBuilder} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame) extends BaseRelation { override def schema: StructType = data.schema } -class ConsoleSinkProvider extends DataSourceV2 - with StreamingWriteSupportProvider +class ConsoleSinkProvider extends TableProvider with DataSourceRegister with CreatableRelationProvider { - override def createStreamingWriteSupport( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceOptions): StreamingWriteSupport = { - new ConsoleWriteSupport(schema, options) + override def getTable(options: CaseInsensitiveStringMap): Table = { + ConsoleTable } def createRelation( @@ -60,3 +60,33 @@ class ConsoleSinkProvider extends DataSourceV2 def shortName(): String = "console" } + +object ConsoleTable extends Table with SupportsWrite { + + override def name(): String = "console" + + override def schema(): StructType = StructType(Nil) + + override def capabilities(): util.Set[TableCapability] = { + Set(TableCapability.STREAMING_WRITE).asJava + } + + override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = { + new WriteBuilder with SupportsTruncate { + private var inputSchema: StructType = _ + + override def withInputDataSchema(schema: StructType): WriteBuilder = { + this.inputSchema = schema + this + } + + // Do nothing for truncate. Console sink is special that it just prints all the records. 
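From the user side, the rewritten console sink above is still reached through format("console"); the no-op truncate() is what lets it satisfy the SupportsTruncate requirement that createStreamingWrite imposes for Complete mode. A usage sketch, assuming an active SparkSession named spark and using the rate source only as sample input.

// Global streaming aggregation so that Complete mode is applicable.
val counts = spark.readStream.format("rate").load().groupBy().count()

val query = counts.writeStream
  .format("console")        // resolved via ConsoleSinkProvider/ConsoleTable above
  .outputMode("complete")   // accepted because the WriteBuilder above supports truncate
  .start()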
+ override def truncate(): WriteBuilder = this + + override def buildForStreaming(): StreamingWrite = { + assert(inputSchema != null) + new ConsoleWrite(inputSchema, options) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index b22795d207760..5475becc5bff4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.execution.streaming.continuous import java.util.UUID import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicReference import java.util.function.UnaryOperator -import scala.collection.JavaConverters._ import scala.collection.mutable.{Map => MutableMap} import org.apache.spark.SparkEnv @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming.{StreamingRelationV2, _} import org.apache.spark.sql.sources.v2 -import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamingWriteSupportProvider, SupportsContinuousRead} +import org.apache.spark.sql.sources.v2.{SupportsRead, SupportsWrite, TableCapability} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, PartitionOffset} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.Clock @@ -42,7 +42,7 @@ class ContinuousExecution( name: String, checkpointRoot: String, analyzedPlan: LogicalPlan, - sink: StreamingWriteSupportProvider, + sink: SupportsWrite, trigger: Trigger, triggerClock: Clock, outputMode: OutputMode, @@ -57,26 +57,29 @@ class ContinuousExecution( // For use only in test harnesses. private[sql] var currentEpochCoordinatorId: String = _ - override val logicalPlan: LogicalPlan = { + // Throwable that caused the execution to fail + private val failure: AtomicReference[Throwable] = new AtomicReference[Throwable](null) + + override val logicalPlan: WriteToContinuousDataSource = { val v2ToRelationMap = MutableMap[StreamingRelationV2, StreamingDataSourceV2Relation]() var nextSourceId = 0 + import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ val _logicalPlan = analyzedPlan.transform { - case s @ StreamingRelationV2( - ds, dsName, table: SupportsContinuousRead, options, output, _) => + case s @ StreamingRelationV2(ds, sourceName, table: SupportsRead, options, output, _) => + if (!table.supports(TableCapability.CONTINUOUS_READ)) { + throw new UnsupportedOperationException( + s"Data source $sourceName does not support continuous processing.") + } + v2ToRelationMap.getOrElseUpdate(s, { val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" nextSourceId += 1 - logInfo(s"Reading table [$table] from DataSourceV2 named '$dsName' [$ds]") - val dsOptions = new DataSourceOptions(options.asJava) + logInfo(s"Reading table [$table] from DataSourceV2 named '$sourceName' [$ds]") // TODO: operator pushdown. 
- val scan = table.newScanBuilder(dsOptions).build() + val scan = table.newScanBuilder(options).build() val stream = scan.toContinuousStream(metadataPath) StreamingDataSourceV2Relation(output, scan, stream) }) - - case StreamingRelationV2(_, sourceName, _, _, _, _) => - throw new UnsupportedOperationException( - s"Data source $sourceName does not support continuous processing.") } sources = _logicalPlan.collect { @@ -84,7 +87,9 @@ class ContinuousExecution( } uniqueSources = sources.distinct - _logicalPlan + // TODO (SPARK-27484): we should add the writing node before the plan is analyzed. + WriteToContinuousDataSource( + createStreamingWrite(sink, extraOptions, _logicalPlan), _logicalPlan) } private val triggerExecutor = trigger match { @@ -174,17 +179,10 @@ class ContinuousExecution( "CurrentTimestamp and CurrentDate not yet supported for continuous processing") } - val writer = sink.createStreamingWriteSupport( - s"$runId", - withNewSources.schema, - outputMode, - new DataSourceOptions(extraOptions.asJava)) - val planWithSink = WriteToContinuousDataSource(writer, withNewSources) - reportTimeTaken("queryPlanning") { lastExecution = new IncrementalExecution( sparkSessionForQuery, - planWithSink, + withNewSources, outputMode, checkpointFile("state"), id, @@ -194,7 +192,7 @@ class ContinuousExecution( lastExecution.executedPlan // Force the lazy generation of execution plan } - val stream = planWithSink.collect { + val stream = withNewSources.collect { case relation: StreamingDataSourceV2Relation => relation.stream.asInstanceOf[ContinuousStream] }.head @@ -214,9 +212,14 @@ class ContinuousExecution( trigger.asInstanceOf[ContinuousTrigger].intervalMs.toString) // Use the parent Spark session for the endpoint since it's where this query ID is registered. - val epochEndpoint = - EpochCoordinatorRef.create( - writer, stream, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) + val epochEndpoint = EpochCoordinatorRef.create( + logicalPlan.write, + stream, + this, + epochCoordinatorId, + currentBatchId, + sparkSession, + SparkEnv.get) val epochUpdateThread = new Thread(new Runnable { override def run: Unit = { try { @@ -258,19 +261,40 @@ class ContinuousExecution( lastExecution.toRdd } } + + val f = failure.get() + if (f != null) { + throw f + } } catch { case t: Throwable if StreamExecution.isInterruptionException(t, sparkSession.sparkContext) && state.get() == RECONFIGURING => logInfo(s"Query $id ignoring exception from reconfiguring: $t") // interrupted by reconfiguration - swallow exception so we can restart the query } finally { - epochEndpoint.askSync[Unit](StopContinuousExecutionWrites) - SparkEnv.get.rpcEnv.stop(epochEndpoint) - - epochUpdateThread.interrupt() - epochUpdateThread.join() - - sparkSession.sparkContext.cancelJobGroup(runId.toString) + // The above execution may finish before getting interrupted, for example, a Spark job having + // 0 partitions will complete immediately. Then the interrupted status will sneak here. + // + // To handle this case, we do the two things here: + // + // 1. Clean up the resources in `queryExecutionThread.runUninterruptibly`. This may increase + // the waiting time of `stop` but should be minor because the operations here are very fast + // (just sending an RPC message in the same process and stopping a very simple thread). + // 2. Clear the interrupted status at the end so that it won't impact the `runContinuous` + // call. 
We may clear the interrupted status set by `stop`, but it doesn't affect the query + // termination because `runActivatedStream` will check `state` and exit accordingly. + queryExecutionThread.runUninterruptibly { + try { + epochEndpoint.askSync[Unit](StopContinuousExecutionWrites) + } finally { + SparkEnv.get.rpcEnv.stop(epochEndpoint) + epochUpdateThread.interrupt() + epochUpdateThread.join() + // The following line must be the last line because it may fail if SparkContext is stopped + sparkSession.sparkContext.cancelJobGroup(runId.toString) + } + } + Thread.interrupted() } } @@ -370,6 +394,35 @@ class ContinuousExecution( } } + /** + * Stores error and stops the query execution thread to terminate the query in new thread. + */ + def stopInNewThread(error: Throwable): Unit = { + if (failure.compareAndSet(null, error)) { + logError(s"Query $prettyIdString received exception $error") + stopInNewThread() + } + } + + /** + * Stops the query execution thread to terminate the query in new thread. + */ + private def stopInNewThread(): Unit = { + new Thread("stop-continuous-execution") { + setDaemon(true) + + override def run(): Unit = { + try { + ContinuousExecution.this.stop() + } catch { + case e: Throwable => + logError(e.getMessage, e) + throw e + } + } + }.start() + } + /** * Stops the query execution thread to terminate the query. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 48ff70f9c9d07..d55f71c7be830 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -23,17 +23,13 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} -import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming._ case class RateStreamPartitionOffset( partition: Int, currentValue: Long, currentTimeMs: Long) extends PartitionOffset -class RateStreamContinuousStream( - rowsPerSecond: Long, - numPartitions: Int, - options: DataSourceOptions) extends ContinuousStream { +class RateStreamContinuousStream(rowsPerSecond: Long, numPartitions: Int) extends ContinuousStream { implicit val defaultFormats: DefaultFormats = DefaultFormats val creationTime = System.currentTimeMillis() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala index e7bc71394061e..2263b42870a65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala @@ -34,9 +34,9 @@ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.streaming.{Offset => _, _} import org.apache.spark.sql.execution.streaming.sources.TextSocketReader -import org.apache.spark.sql.sources.v2.DataSourceOptions 
import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.RpcUtils @@ -49,7 +49,7 @@ import org.apache.spark.util.RpcUtils * buckets and serves the messages to the executors via a RPC endpoint. */ class TextSocketContinuousStream( - host: String, port: Int, numPartitions: Int, options: DataSourceOptions) + host: String, port: Int, numPartitions: Int, options: CaseInsensitiveStringMap) extends ContinuousStream with Logging { implicit val defaultFormats: DefaultFormats = DefaultFormats diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index d1bda79f4b6ef..decf524f7167c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -25,7 +25,7 @@ import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeR import org.apache.spark.sql.SparkSession import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite import org.apache.spark.util.RpcUtils private[continuous] sealed trait EpochCoordinatorMessage extends Serializable @@ -82,7 +82,7 @@ private[sql] object EpochCoordinatorRef extends Logging { * Create a reference to a new [[EpochCoordinator]]. */ def create( - writeSupport: StreamingWriteSupport, + writeSupport: StreamingWrite, stream: ContinuousStream, query: ContinuousExecution, epochCoordinatorId: String, @@ -115,7 +115,7 @@ private[sql] object EpochCoordinatorRef extends Logging { * have both committed and reported an end offset for a given epoch. 
*/ private[continuous] class EpochCoordinator( - writeSupport: StreamingWriteSupport, + writeSupport: StreamingWrite, stream: ContinuousStream, query: ContinuousExecution, startEpoch: Long, @@ -123,6 +123,9 @@ private[continuous] class EpochCoordinator( override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { + private val epochBacklogQueueSize = + session.sqlContext.conf.continuousStreamingEpochBacklogQueueSize + private var queryWritesStopped: Boolean = false private var numReaderPartitions: Int = _ @@ -212,6 +215,7 @@ private[continuous] class EpochCoordinator( if (!partitionCommits.isDefinedAt((epoch, partitionId))) { partitionCommits.put((epoch, partitionId), message) resolveCommitsAtEpoch(epoch) + checkProcessingQueueBoundaries() } case ReportPartitionOffset(partitionId, epoch, offset) => @@ -223,6 +227,22 @@ private[continuous] class EpochCoordinator( query.addOffset(epoch, stream, thisEpochOffsets.toSeq) resolveCommitsAtEpoch(epoch) } + checkProcessingQueueBoundaries() + } + + private def checkProcessingQueueBoundaries() = { + if (partitionOffsets.size > epochBacklogQueueSize) { + query.stopInNewThread(new IllegalStateException("Size of the partition offset queue has " + + "exceeded its maximum")) + } + if (partitionCommits.size > epochBacklogQueueSize) { + query.stopInNewThread(new IllegalStateException("Size of the partition commit queue has " + + "exceeded its maximum")) + } + if (epochsWaitingToBeCommitted.size > epochBacklogQueueSize) { + query.stopInNewThread(new IllegalStateException("Size of the epoch queue has " + + "exceeded its maximum")) + } } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala index 7ad21cc304e7c..54f484c4adae3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala @@ -19,13 +19,13 @@ package org.apache.spark.sql.execution.streaming.continuous import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite /** * The logical plan for writing data in a continuous stream. 
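The queue checks above act as a simple backpressure guard: once any of the coordinator's per-epoch bookkeeping collections outgrows the configured epoch backlog size, the query is failed from a separate thread rather than from inside the RPC handler. A stripped-down sketch of the same pattern; the names and the threshold below are stand-ins, not the real configuration key.

import scala.collection.mutable

// Stand-ins for one piece of coordinator state and the configured backlog limit.
val maxBacklog = 10000
val partitionCommits = mutable.Map[(Long, Int), String]()

// Same shape as checkProcessingQueueBoundaries above: detect the overflow and hand a
// descriptive error to a callback that stops the query asynchronously.
def checkBacklog(stopQueryAsync: Throwable => Unit): Unit = {
  if (partitionCommits.size > maxBacklog) {
    stopQueryAsync(new IllegalStateException(
      "Size of the partition commit queue has exceeded its maximum"))
  }
}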
*/ -case class WriteToContinuousDataSource( - writeSupport: StreamingWriteSupport, query: LogicalPlan) extends LogicalPlan { +case class WriteToContinuousDataSource(write: StreamingWrite, query: LogicalPlan) + extends LogicalPlan { override def children: Seq[LogicalPlan] = Seq(query) override def output: Seq[Attribute] = Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala index 2178466d63142..2f3af6a6544c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala @@ -26,21 +26,22 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite /** - * The physical plan for writing data into a continuous processing [[StreamingWriteSupport]]. + * The physical plan for writing data into a continuous processing [[StreamingWrite]]. */ -case class WriteToContinuousDataSourceExec(writeSupport: StreamingWriteSupport, query: SparkPlan) - extends UnaryExecNode with Logging { +case class WriteToContinuousDataSourceExec(write: StreamingWrite, query: SparkPlan) + extends UnaryExecNode with Logging { + override def child: SparkPlan = query override def output: Seq[Attribute] = Nil override protected def doExecute(): RDD[InternalRow] = { - val writerFactory = writeSupport.createStreamingWriterFactory() + val writerFactory = write.createStreamingWriterFactory() val rdd = new ContinuousWriteRDD(query.execute(), writerFactory) - logInfo(s"Start processing data source write support: $writeSupport. " + + logInfo(s"Start processing data source write support: $write. 
" + s"The input RDD has ${rdd.partitions.length} partitions.") EpochCoordinatorRef.get( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index e71f81caeb974..df149552dfb30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -17,27 +17,26 @@ package org.apache.spark.sql.execution.streaming +import java.util import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy -import scala.collection.mutable.{ArrayBuffer, ListBuffer} -import scala.util.control.NonFatal +import scala.collection.JavaConverters._ +import scala.collection.mutable.ListBuffer import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.encoderFor -import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} -import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils -import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream, Offset => OffsetV2} -import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream, Offset => OffsetV2, SparkDataStream} import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap object MemoryStream { protected val currentBlockId = new AtomicInteger(0) @@ -50,7 +49,7 @@ object MemoryStream { /** * A base class for memory stream implementations. Supports adding data and resetting. 
*/ -abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends BaseStreamingSource { +abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends SparkDataStream { val encoder = encoderFor[A] protected val attributes = encoder.schema.toAttributes @@ -62,10 +61,12 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Bas Dataset.ofRows(sqlContext.sparkSession, logicalPlan) } - def addData(data: A*): Offset = { + def addData(data: A*): OffsetV2 = { addData(data.toTraversable) } + def addData(data: TraversableOnce[A]): OffsetV2 + def fullSchema(): StructType = encoder.schema protected val logicalPlan: LogicalPlan = { @@ -73,30 +74,43 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Bas MemoryStreamTableProvider, "memory", new MemoryStreamTable(this), - Map.empty, + CaseInsensitiveStringMap.empty(), attributes, None)(sqlContext.sparkSession) } - def addData(data: TraversableOnce[A]): Offset + override def initialOffset(): OffsetV2 = { + throw new IllegalStateException("should not be called.") + } + + override def deserializeOffset(json: String): OffsetV2 = { + throw new IllegalStateException("should not be called.") + } + + override def commit(end: OffsetV2): Unit = { + throw new IllegalStateException("should not be called.") + } } // This class is used to indicate the memory stream data source. We don't actually use it, as // memory stream is for test only and we never look it up by name. object MemoryStreamTableProvider extends TableProvider { - override def getTable(options: DataSourceOptions): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { throw new IllegalStateException("MemoryStreamTableProvider should not be used.") } } -class MemoryStreamTable(val stream: MemoryStreamBase[_]) extends Table - with SupportsMicroBatchRead with SupportsContinuousRead { +class MemoryStreamTable(val stream: MemoryStreamBase[_]) extends Table with SupportsRead { override def name(): String = "MemoryStreamDataSource" override def schema(): StructType = stream.fullSchema() - override def newScanBuilder(options: DataSourceOptions): ScanBuilder = { + override def capabilities(): util.Set[TableCapability] = { + Set(TableCapability.MICRO_BATCH_READ, TableCapability.CONTINUOUS_READ).asJava + } + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MemoryStreamScanBuilder(stream) } } @@ -212,22 +226,15 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } override def commit(end: OffsetV2): Unit = synchronized { - def check(newOffset: LongOffset): Unit = { - val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt - - if (offsetDiff < 0) { - sys.error(s"Offsets committed out of order: $lastOffsetCommitted followed by $end") - } + val newOffset = end.asInstanceOf[LongOffset] + val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt - batches.trimStart(offsetDiff) - lastOffsetCommitted = newOffset + if (offsetDiff < 0) { + sys.error(s"Offsets committed out of order: $lastOffsetCommitted followed by $end") } - LongOffset.convert(end) match { - case Some(lo) => check(lo) - case None => sys.error(s"MemoryStream.commit() received an offset ($end) " + - "that did not originate with an instance of this class") - } + batches.trimStart(offsetDiff) + lastOffsetCommitted = newOffset } override def stop() {} @@ -262,93 +269,3 @@ object MemoryStreamReaderFactory extends PartitionReaderFactory { } } } - -/** A common trait for 
MemorySinks with methods used for testing */ -trait MemorySinkBase extends BaseStreamingSink { - def allData: Seq[Row] - def latestBatchData: Seq[Row] - def dataSinceBatch(sinceBatchId: Long): Seq[Row] - def latestBatchId: Option[Long] -} - -/** - * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit - * tests and does not provide durability. - */ -class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink - with MemorySinkBase with Logging { - - private case class AddedData(batchId: Long, data: Array[Row]) - - /** An order list of batches that have been written to this [[Sink]]. */ - @GuardedBy("this") - private val batches = new ArrayBuffer[AddedData]() - - /** Returns all rows that are stored in this [[Sink]]. */ - def allData: Seq[Row] = synchronized { - batches.flatMap(_.data) - } - - def latestBatchId: Option[Long] = synchronized { - batches.lastOption.map(_.batchId) - } - - def latestBatchData: Seq[Row] = synchronized { batches.lastOption.toSeq.flatten(_.data) } - - def dataSinceBatch(sinceBatchId: Long): Seq[Row] = synchronized { - batches.filter(_.batchId > sinceBatchId).flatMap(_.data) - } - - def toDebugString: String = synchronized { - batches.map { case AddedData(batchId, data) => - val dataStr = try data.mkString(" ") catch { - case NonFatal(e) => "[Error converting to string]" - } - s"$batchId: $dataStr" - }.mkString("\n") - } - - override def addBatch(batchId: Long, data: DataFrame): Unit = { - val notCommitted = synchronized { - latestBatchId.isEmpty || batchId > latestBatchId.get - } - if (notCommitted) { - logDebug(s"Committing batch $batchId to $this") - outputMode match { - case Append | Update => - val rows = AddedData(batchId, data.collect()) - synchronized { batches += rows } - - case Complete => - val rows = AddedData(batchId, data.collect()) - synchronized { - batches.clear() - batches += rows - } - - case _ => - throw new IllegalArgumentException( - s"Output mode $outputMode is not supported by MemorySink") - } - } else { - logDebug(s"Skipping already committed batch: $batchId") - } - } - - def clear(): Unit = synchronized { - batches.clear() - } - - override def toString(): String = "MemorySink" -} - -/** - * Used to query the data that has been written into a [[MemorySink]]. 
- */ -case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode { - def this(sink: MemorySink) = this(sink, sink.schema.toAttributes) - - private val sizePerRow = EstimationUtils.getSizePerRow(sink.schema.toAttributes) - - override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWrite.scala similarity index 92% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWrite.scala index 833e62f35ede1..dbe242784986d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWrite.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap /** Common methods used to create writes for the the console sink */ -class ConsoleWriteSupport(schema: StructType, options: DataSourceOptions) - extends StreamingWriteSupport with Logging { +class ConsoleWrite(schema: StructType, options: CaseInsensitiveStringMap) + extends StreamingWrite with Logging { // Number of rows to display, by default 20 rows protected val numRowsToShow = options.getInt("numRows", 20) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala similarity index 60% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala index 4218fd51ad206..6da1b3a49c442 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala @@ -17,68 +17,88 @@ package org.apache.spark.sql.execution.streaming.sources +import java.util + +import scala.collection.JavaConverters._ + import org.apache.spark.sql.{ForeachWriter, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.python.PythonForeachWriter -import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamingWriteSupportProvider} -import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage} -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} 
-import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.sources.v2.{SupportsWrite, Table, TableCapability} +import org.apache.spark.sql.sources.v2.writer.{DataWriter, SupportsTruncate, WriteBuilder, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap /** - * A [[org.apache.spark.sql.sources.v2.DataSourceV2]] for forwarding data into the specified - * [[ForeachWriter]]. + * A write-only table for forwarding data into the specified [[ForeachWriter]]. * * @param writer The [[ForeachWriter]] to process all data. * @param converter An object to convert internal rows to target type T. Either it can be * a [[ExpressionEncoder]] or a direct converter function. * @tparam T The expected type of the sink. */ -case class ForeachWriteSupportProvider[T]( +case class ForeachWriterTable[T]( writer: ForeachWriter[T], converter: Either[ExpressionEncoder[T], InternalRow => T]) - extends StreamingWriteSupportProvider { - - override def createStreamingWriteSupport( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceOptions): StreamingWriteSupport = { - new StreamingWriteSupport { - override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - - override def createStreamingWriterFactory(): StreamingDataWriterFactory = { - val rowConverter: InternalRow => T = converter match { - case Left(enc) => - val boundEnc = enc.resolveAndBind( - schema.toAttributes, - SparkSession.getActiveSession.get.sessionState.analyzer) - boundEnc.fromRow - case Right(func) => - func - } - ForeachWriterFactory(writer, rowConverter) + extends Table with SupportsWrite { + + override def name(): String = "ForeachSink" + + override def schema(): StructType = StructType(Nil) + + override def capabilities(): util.Set[TableCapability] = { + Set(TableCapability.STREAMING_WRITE).asJava + } + + override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = { + new WriteBuilder with SupportsTruncate { + private var inputSchema: StructType = _ + + override def withInputDataSchema(schema: StructType): WriteBuilder = { + this.inputSchema = schema + this } - override def toString: String = "ForeachSink" + // Do nothing for truncate. Foreach sink is special that it just forwards all the records to + // ForeachWriter. 
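      // Editor's illustration (not part of the original patch): the kind of user code that ends
      // up wrapped in a ForeachWriterTable, assuming a streaming Dataset[String] named `ds`:
      //   ds.writeStream.foreach(new ForeachWriter[String] {
      //     def open(partitionId: Long, epochId: Long): Boolean = true   // accept every partition
      //     def process(value: String): Unit = println(value)            // forward each record
      //     def close(errorOrNull: Throwable): Unit = ()
      //   }).start()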
+ override def truncate(): WriteBuilder = this + + override def buildForStreaming(): StreamingWrite = { + new StreamingWrite { + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + + override def createStreamingWriterFactory(): StreamingDataWriterFactory = { + val rowConverter: InternalRow => T = converter match { + case Left(enc) => + val boundEnc = enc.resolveAndBind( + inputSchema.toAttributes, + SparkSession.getActiveSession.get.sessionState.analyzer) + boundEnc.fromRow + case Right(func) => + func + } + ForeachWriterFactory(writer, rowConverter) + } + } + } } } } -object ForeachWriteSupportProvider { +object ForeachWriterTable { def apply[T]( writer: ForeachWriter[T], - encoder: ExpressionEncoder[T]): ForeachWriteSupportProvider[_] = { + encoder: ExpressionEncoder[T]): ForeachWriterTable[_] = { writer match { case pythonWriter: PythonForeachWriter => - new ForeachWriteSupportProvider[UnsafeRow]( + new ForeachWriterTable[UnsafeRow]( pythonWriter, Right((x: InternalRow) => x.asInstanceOf[UnsafeRow])) case _ => - new ForeachWriteSupportProvider[T](writer, Left(encoder)) + new ForeachWriterTable[T](writer, Left(encoder)) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala index 143235efee81d..f3951897ea747 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala @@ -19,14 +19,14 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.v2.writer.{BatchWrite, DataWriter, DataWriterFactory, WriterCommitMessage} -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} /** * A [[BatchWrite]] used to hook V2 stream writers into a microbatch plan. It implements * the non-streaming interface, forwarding the epoch ID determined at construction to a wrapped * streaming write support. 
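 *
 * Editor's illustration (not part of the original patch), assuming `streamingWrite` is some
 * [[StreamingWrite]] and `messages` are the commit messages collected for that epoch:
 * {{{
 *   val batchWrite: BatchWrite = new MicroBatchWrite(5L, streamingWrite)
 *   batchWrite.commit(messages)   // forwards to streamingWrite.commit(5L, messages)
 * }}}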
*/ -class MicroBatchWrite(eppchId: Long, val writeSupport: StreamingWriteSupport) extends BatchWrite { +class MicroBatchWrite(eppchId: Long, val writeSupport: StreamingWrite) extends BatchWrite { override def commit(messages: Array[WriterCommitMessage]): Unit = { writeSupport.commit(eppchId, messages) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchStream.scala index a8feed34b96dc..5403eafd54b61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchStream.scala @@ -28,9 +28,9 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, Offset} +import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.{ManualClock, SystemClock} class RateStreamMicroBatchStream( @@ -38,7 +38,7 @@ class RateStreamMicroBatchStream( // The default values here are used in tests. rampUpTimeSeconds: Long = 0, numPartitions: Int = 1, - options: DataSourceOptions, + options: CaseInsensitiveStringMap, checkpointLocation: String) extends MicroBatchStream with Logging { import RateStreamProvider._ @@ -155,7 +155,7 @@ class RateStreamMicroBatchStream( override def toString: String = s"RateStreamV2[rowsPerSecond=$rowsPerSecond, " + s"rampUpTimeSeconds=$rampUpTimeSeconds, " + - s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" + s"numPartitions=${options.getOrDefault(NUM_PARTITIONS, "default")}" } case class RateStreamMicroBatchInputPartition( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala index 075c6b9362ba2..8dbae9f787cf0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql.execution.streaming.sources +import java.util + +import scala.collection.JavaConverters._ + import org.apache.spark.network.util.JavaUtils import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousStream @@ -25,6 +29,7 @@ import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader.{Scan, ScanBuilder} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream} import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap /** * A source that generates increment long values with timestamps. Each generated row has two @@ -40,18 +45,17 @@ import org.apache.spark.sql.types._ * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. 
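 *
 * Editor's illustration (not part of the original patch), assuming a SparkSession named `spark`:
 * {{{
 *   val rates = spark.readStream
 *     .format("rate")
 *     .option("rowsPerSecond", 10)
 *     .option("numPartitions", 2)
 *     .load()   // schema: timestamp TIMESTAMP, value LONG
 * }}}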
*/ -class RateStreamProvider extends DataSourceV2 - with TableProvider with DataSourceRegister { +class RateStreamProvider extends TableProvider with DataSourceRegister { import RateStreamProvider._ - override def getTable(options: DataSourceOptions): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { val rowsPerSecond = options.getLong(ROWS_PER_SECOND, 1) if (rowsPerSecond <= 0) { throw new IllegalArgumentException( s"Invalid value '$rowsPerSecond'. The option 'rowsPerSecond' must be positive") } - val rampUpTimeSeconds = Option(options.get(RAMP_UP_TIME).orElse(null)) + val rampUpTimeSeconds = Option(options.get(RAMP_UP_TIME)) .map(JavaUtils.timeStringAsSec) .getOrElse(0L) if (rampUpTimeSeconds < 0) { @@ -75,7 +79,7 @@ class RateStreamTable( rowsPerSecond: Long, rampUpTimeSeconds: Long, numPartitions: Int) - extends Table with SupportsMicroBatchRead with SupportsContinuousRead { + extends Table with SupportsRead { override def name(): String = { s"RateStream(rowsPerSecond=$rowsPerSecond, rampUpTimeSeconds=$rampUpTimeSeconds, " + @@ -84,7 +88,11 @@ class RateStreamTable( override def schema(): StructType = RateStreamProvider.SCHEMA - override def newScanBuilder(options: DataSourceOptions): ScanBuilder = new ScanBuilder { + override def capabilities(): util.Set[TableCapability] = { + Set(TableCapability.MICRO_BATCH_READ, TableCapability.CONTINUOUS_READ).asJava + } + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new ScanBuilder { override def build(): Scan = new Scan { override def readSchema(): StructType = RateStreamProvider.SCHEMA @@ -94,7 +102,7 @@ class RateStreamTable( } override def toContinuousStream(checkpointLocation: String): ContinuousStream = { - new RateStreamContinuousStream(rowsPerSecond, numPartitions, options) + new RateStreamContinuousStream(rowsPerSecond, numPartitions) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketMicroBatchStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketMicroBatchStream.scala index 540131c8de8a1..dd8d89238008e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketMicroBatchStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketMicroBatchStream.scala @@ -29,7 +29,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming.LongOffset -import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader.{InputPartition, PartitionReader, PartitionReaderFactory} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, Offset} import org.apache.spark.unsafe.types.UTF8String @@ -39,8 +38,7 @@ import org.apache.spark.unsafe.types.UTF8String * and debugging. This MicroBatchReadSupport will *not* work in production applications due to * multiple reasons, including no support for fault recovery. 
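 *
 * Editor's illustration (not part of the original patch), assuming a SparkSession named `spark`
 * and a process listening on localhost:9999 (e.g. `nc -lk 9999`):
 * {{{
 *   val lines = spark.readStream
 *     .format("socket")
 *     .option("host", "localhost")
 *     .option("port", 9999)
 *     .load()   // schema: value STRING (plus timestamp if includeTimestamp=true)
 * }}}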
*/ -class TextSocketMicroBatchStream( - host: String, port: Int, numPartitions: Int, options: DataSourceOptions) +class TextSocketMicroBatchStream(host: String, port: Int, numPartitions: Int) extends MicroBatchStream with Logging { @GuardedBy("this") @@ -155,10 +153,7 @@ class TextSocketMicroBatchStream( } override def commit(end: Offset): Unit = synchronized { - val newOffset = LongOffset.convert(end).getOrElse( - sys.error(s"TextSocketStream.commit() received an offset ($end) that did not " + - s"originate with an instance of this class") - ) + val newOffset = end.asInstanceOf[LongOffset] val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala index c3b24a8f65dd9..e714859c16ddd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala @@ -18,8 +18,10 @@ package org.apache.spark.sql.execution.streaming.sources import java.text.SimpleDateFormat +import java.util import java.util.Locale +import scala.collection.JavaConverters._ import scala.util.{Failure, Success, Try} import org.apache.spark.internal.Logging @@ -30,21 +32,21 @@ import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader.{Scan, ScanBuilder} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream} import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap -class TextSocketSourceProvider extends DataSourceV2 - with TableProvider with DataSourceRegister with Logging { +class TextSocketSourceProvider extends TableProvider with DataSourceRegister with Logging { - private def checkParameters(params: DataSourceOptions): Unit = { + private def checkParameters(params: CaseInsensitiveStringMap): Unit = { logWarning("The socket source should not be used for production applications! 
" + "It does not support recovery.") - if (!params.get("host").isPresent) { + if (!params.containsKey("host")) { throw new AnalysisException("Set a host to read from with option(\"host\", ...).") } - if (!params.get("port").isPresent) { + if (!params.containsKey("port")) { throw new AnalysisException("Set a port to read from with option(\"port\", ...).") } Try { - params.get("includeTimestamp").orElse("false").toBoolean + params.getBoolean("includeTimestamp", false) } match { case Success(_) => case Failure(_) => @@ -52,10 +54,10 @@ class TextSocketSourceProvider extends DataSourceV2 } } - override def getTable(options: DataSourceOptions): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { checkParameters(options) new TextSocketTable( - options.get("host").get, + options.get("host"), options.getInt("port", -1), options.getInt("numPartitions", SparkSession.active.sparkContext.defaultParallelism), options.getBoolean("includeTimestamp", false)) @@ -66,7 +68,7 @@ class TextSocketSourceProvider extends DataSourceV2 } class TextSocketTable(host: String, port: Int, numPartitions: Int, includeTimestamp: Boolean) - extends Table with SupportsMicroBatchRead with SupportsContinuousRead { + extends Table with SupportsRead { override def name(): String = s"Socket[$host:$port]" @@ -78,12 +80,16 @@ class TextSocketTable(host: String, port: Int, numPartitions: Int, includeTimest } } - override def newScanBuilder(options: DataSourceOptions): ScanBuilder = new ScanBuilder { + override def capabilities(): util.Set[TableCapability] = { + Set(TableCapability.MICRO_BATCH_READ, TableCapability.CONTINUOUS_READ).asJava + } + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new ScanBuilder { override def build(): Scan = new Scan { override def readSchema(): StructType = schema() override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = { - new TextSocketMicroBatchStream(host, port, numPartitions, options) + new TextSocketMicroBatchStream(host, port, numPartitions) } override def toContinuousStream(checkpointLocation: String): ContinuousStream = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala new file mode 100644 index 0000000000000..a3f58fa966fe8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.datasources.v2.WriteToDataSourceV2 +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite + +/** + * The logical plan for writing data to a micro-batch stream. + * + * Note that this logical plan does not have a corresponding physical plan, as it will be converted + * to [[WriteToDataSourceV2]] with [[MicroBatchWrite]] before execution. + */ +case class WriteToMicroBatchDataSource(write: StreamingWrite, query: LogicalPlan) + extends LogicalPlan { + override def children: Seq[LogicalPlan] = Seq(query) + override def output: Seq[Attribute] = Nil + + def createPlan(batchId: Long): WriteToDataSourceV2 = { + WriteToDataSourceV2(new MicroBatchWrite(batchId, write), query) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala similarity index 69% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala index c50dc7bcb8da1..de8d00d4ac348 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala @@ -17,8 +17,10 @@ package org.apache.spark.sql.execution.streaming.sources +import java.util import javax.annotation.concurrent.GuardedBy +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal @@ -30,27 +32,46 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils -import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} -import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamingWriteSupportProvider} +import org.apache.spark.sql.execution.streaming.Sink +import org.apache.spark.sql.sources.v2.{SupportsWrite, Table, TableCapability} import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} -import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap /** * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. 
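 *
 * Editor's illustration (not part of the original patch), assuming a streaming DataFrame `df`
 * and a SparkSession named `spark`:
 * {{{
 *   val query = df.writeStream
 *     .format("memory")
 *     .queryName("results")        // required: the in-memory table name
 *     .outputMode("complete")      // Complete mode leads to the truncate() path below
 *     .start()
 *   spark.table("results").show()  // query what the sink has collected so far
 * }}}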
*/ -class MemorySinkV2 extends DataSourceV2 with StreamingWriteSupportProvider - with MemorySinkBase with Logging { +class MemorySink extends Table with SupportsWrite with Logging { - override def createStreamingWriteSupport( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceOptions): StreamingWriteSupport = { - new MemoryStreamingWriteSupport(this, mode, schema) + override def name(): String = "MemorySink" + + override def schema(): StructType = StructType(Nil) + + override def capabilities(): util.Set[TableCapability] = { + Set(TableCapability.STREAMING_WRITE).asJava + } + + override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = { + new WriteBuilder with SupportsTruncate { + private var needTruncate: Boolean = false + private var inputSchema: StructType = _ + + override def truncate(): WriteBuilder = { + this.needTruncate = true + this + } + + override def withInputDataSchema(schema: StructType): WriteBuilder = { + this.inputSchema = schema + this + } + + override def buildForStreaming(): StreamingWrite = { + new MemoryStreamingWrite(MemorySink.this, inputSchema, needTruncate) + } + } } private case class AddedData(batchId: Long, data: Array[Row]) @@ -85,27 +106,20 @@ class MemorySinkV2 extends DataSourceV2 with StreamingWriteSupportProvider }.mkString("\n") } - def write(batchId: Long, outputMode: OutputMode, newRows: Array[Row]): Unit = { + def write(batchId: Long, needTruncate: Boolean, newRows: Array[Row]): Unit = { val notCommitted = synchronized { latestBatchId.isEmpty || batchId > latestBatchId.get } if (notCommitted) { logDebug(s"Committing batch $batchId to $this") - outputMode match { - case Append | Update => - val rows = AddedData(batchId, newRows) - synchronized { batches += rows } - - case Complete => - val rows = AddedData(batchId, newRows) - synchronized { - batches.clear() - batches += rows - } - - case _ => - throw new IllegalArgumentException( - s"Output mode $outputMode is not supported by MemorySinkV2") + val rows = AddedData(batchId, newRows) + if (needTruncate) { + synchronized { + batches.clear() + batches += rows + } + } else { + synchronized { batches += rows } } } else { logDebug(s"Skipping already committed batch: $batchId") @@ -116,25 +130,25 @@ class MemorySinkV2 extends DataSourceV2 with StreamingWriteSupportProvider batches.clear() } - override def toString(): String = "MemorySinkV2" + override def toString(): String = "MemorySink" } case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {} -class MemoryStreamingWriteSupport( - val sink: MemorySinkV2, outputMode: OutputMode, schema: StructType) - extends StreamingWriteSupport { +class MemoryStreamingWrite( + val sink: MemorySink, schema: StructType, needTruncate: Boolean) + extends StreamingWrite { override def createStreamingWriterFactory: MemoryWriterFactory = { - MemoryWriterFactory(outputMode, schema) + MemoryWriterFactory(schema) } override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { val newRows = messages.flatMap { case message: MemoryWriterCommitMessage => message.data } - sink.write(epochId, outputMode, newRows) + sink.write(epochId, needTruncate, newRows) } override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { @@ -142,13 +156,13 @@ class MemoryStreamingWriteSupport( } } -case class MemoryWriterFactory(outputMode: OutputMode, schema: StructType) +case class MemoryWriterFactory(schema: StructType) extends DataWriterFactory with 
StreamingDataWriterFactory { override def createWriter( partitionId: Int, taskId: Long): DataWriter[InternalRow] = { - new MemoryDataWriter(partitionId, outputMode, schema) + new MemoryDataWriter(partitionId, schema) } override def createWriter( @@ -159,7 +173,7 @@ case class MemoryWriterFactory(outputMode: OutputMode, schema: StructType) } } -class MemoryDataWriter(partition: Int, outputMode: OutputMode, schema: StructType) +class MemoryDataWriter(partition: Int, schema: StructType) extends DataWriter[InternalRow] with Logging { private val data = mutable.Buffer[Row]() @@ -181,9 +195,9 @@ class MemoryDataWriter(partition: Int, outputMode: OutputMode, schema: StructTyp /** - * Used to query the data that has been written into a [[MemorySinkV2]]. + * Used to query the data that has been written into a [[MemorySink]]. */ -case class MemoryPlanV2(sink: MemorySinkV2, override val output: Seq[Attribute]) extends LeafNode { +case class MemoryPlan(sink: MemorySink, override val output: Seq[Attribute]) extends LeafNode { private val sizePerRow = EstimationUtils.getSizePerRow(output) override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index a605dc640dc96..18029abb08dab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser} import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.v2.{V2StreamingScanSupportCheck, V2WriteSupportCheck} import org.apache.spark.sql.streaming.StreamingQueryManager import org.apache.spark.sql.util.ExecutionListenerManager @@ -160,6 +161,7 @@ abstract class BaseSessionStateBuilder( new FindDataSourceTable(session) +: new ResolveSQLOnFile(session) +: new FallbackOrcDataSourceV2(session) +: + DataSourceResolution(conf, session.catalog(_)) +: customResolutionRules override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = @@ -172,6 +174,8 @@ abstract class BaseSessionStateBuilder( PreWriteCheck +: PreReadCheck +: HiveOnlyCheck +: + V2WriteSupportCheck +: + V2StreamingScanSupportCheck +: customCheckRules } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index a10bd2218eb38..da4723e34c0d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -30,7 +30,9 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} import org.apache.spark.sql.sources.StreamSourceProvider import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.TableCapability._ import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap /** * Interface used to load a streaming `Dataset` from external storage systems (e.g. 
file systems, @@ -173,22 +175,24 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo ds match { case provider: TableProvider => val sessionOptions = DataSourceV2Utils.extractSessionConfigs( - ds = provider, conf = sparkSession.sessionState.conf) + source = provider, conf = sparkSession.sessionState.conf) val options = sessionOptions ++ extraOptions - val dsOptions = new DataSourceOptions(options.asJava) + val dsOptions = new CaseInsensitiveStringMap(options.asJava) val table = userSpecifiedSchema match { case Some(schema) => provider.getTable(dsOptions, schema) case _ => provider.getTable(dsOptions) } + import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ table match { - case _: SupportsMicroBatchRead | _: SupportsContinuousRead => + case _: SupportsRead if table.supportsAny(MICRO_BATCH_READ, CONTINUOUS_READ) => Dataset.ofRows( sparkSession, StreamingRelationV2( - provider, source, table, options, table.schema.toAttributes, v1Relation)( + provider, source, table, dsOptions, table.schema.toAttributes, v1Relation)( sparkSession)) // fallback to v1 + // TODO (SPARK-27483): we should move this fallback logic to an analyzer rule. case _ => Dataset.ofRows(sparkSession, StreamingRelation(v1DataSource)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index ea596ba728c19..d051cf9c1d4a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -31,7 +31,9 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources._ -import org.apache.spark.sql.sources.v2.StreamingWriteSupportProvider +import org.apache.spark.sql.sources.v2.{SupportsWrite, TableProvider} +import org.apache.spark.sql.sources.v2.TableCapability._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. 
file systems, @@ -252,16 +254,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { if (extraOptions.get("queryName").isEmpty) { throw new AnalysisException("queryName must be specified for memory sink") } - val (sink, resultDf) = trigger match { - case _: ContinuousTrigger => - val s = new MemorySinkV2() - val r = Dataset.ofRows(df.sparkSession, new MemoryPlanV2(s, df.schema.toAttributes)) - (s, r) - case _ => - val s = new MemorySink(df.schema, outputMode) - val r = Dataset.ofRows(df.sparkSession, new MemoryPlan(s)) - (s, r) - } + val sink = new MemorySink() + val resultDf = Dataset.ofRows(df.sparkSession, new MemoryPlan(sink, df.schema.toAttributes)) val chkpointLoc = extraOptions.get("checkpointLocation") val recoverFromChkpoint = outputMode == OutputMode.Complete() val query = df.sparkSession.sessionState.streamingQueryManager.startQuery( @@ -278,7 +272,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { query } else if (source == "foreach") { assertNotPartitioned("foreach") - val sink = ForeachWriteSupportProvider[T](foreachWriter, ds.exprEnc) + val sink = ForeachWriterTable[T](foreachWriter, ds.exprEnc) df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), @@ -304,30 +298,31 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { useTempCheckpointLocation = true, trigger = trigger) } else { - val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) + val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",") - var options = extraOptions.toMap - val sink = ds.getConstructor().newInstance() match { - case w: StreamingWriteSupportProvider - if !disabledSources.contains(w.getClass.getCanonicalName) => - val sessionOptions = DataSourceV2Utils.extractSessionConfigs( - w, df.sparkSession.sessionState.conf) - options = sessionOptions ++ extraOptions - w - case _ => - val ds = DataSource( - df.sparkSession, - className = source, - options = options, - partitionColumns = normalizedParCols.getOrElse(Nil)) - ds.createSink(outputMode) + val useV1Source = disabledSources.contains(cls.getCanonicalName) + + val sink = if (classOf[TableProvider].isAssignableFrom(cls) && !useV1Source) { + val provider = cls.getConstructor().newInstance().asInstanceOf[TableProvider] + val sessionOptions = DataSourceV2Utils.extractSessionConfigs( + source = provider, conf = df.sparkSession.sessionState.conf) + val options = sessionOptions ++ extraOptions + val dsOptions = new CaseInsensitiveStringMap(options.asJava) + import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ + provider.getTable(dsOptions) match { + case table: SupportsWrite if table.supports(STREAMING_WRITE) => + table + case _ => createV1Sink() + } + } else { + createV1Sink() } df.sparkSession.sessionState.streamingQueryManager.startQuery( - options.get("queryName"), - options.get("checkpointLocation"), + extraOptions.get("queryName"), + extraOptions.get("checkpointLocation"), df, - options, + extraOptions.toMap, sink, outputMode, useTempCheckpointLocation = source == "console" || source == "noop", @@ -336,6 +331,15 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { } } + private def createV1Sink(): Sink = { + val ds = DataSource( + df.sparkSession, + className = source, + options = extraOptions.toMap, + partitionColumns = normalizedParCols.getOrElse(Nil)) + 
ds.createSink(outputMode) + } + /** * Sets the output of the streaming query to be processed using the provided writer object. * object. See [[org.apache.spark.sql.ForeachWriter]] for more details on the lifecycle and diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 881cd96cc9dc9..63fb9ed176b9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.STREAMING_QUERY_LISTENERS -import org.apache.spark.sql.sources.v2.StreamingWriteSupportProvider +import org.apache.spark.sql.sources.v2.{SupportsWrite, Table} import org.apache.spark.util.{Clock, SystemClock, Utils} /** @@ -206,7 +206,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo userSpecifiedCheckpointLocation: Option[String], df: DataFrame, extraOptions: Map[String, String], - sink: BaseStreamingSink, + sink: Table, outputMode: OutputMode, useTempCheckpointLocation: Boolean, recoverFromCheckpointLocation: Boolean, @@ -214,16 +214,20 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo triggerClock: Clock): StreamingQueryWrapper = { var deleteCheckpointOnStop = false val checkpointLocation = userSpecifiedCheckpointLocation.map { userSpecified => - new Path(userSpecified).toUri.toString + new Path(userSpecified).toString }.orElse { df.sparkSession.sessionState.conf.checkpointLocation.map { location => - new Path(location, userSpecifiedName.getOrElse(UUID.randomUUID().toString)).toUri.toString + new Path(location, userSpecifiedName.getOrElse(UUID.randomUUID().toString)).toString } }.getOrElse { if (useTempCheckpointLocation) { - // Delete the temp checkpoint when a query is being stopped without errors. deleteCheckpointOnStop = true - Utils.createTempDir(namePrefix = s"temporary").getCanonicalPath + val tempDir = Utils.createTempDir(namePrefix = s"temporary").getCanonicalPath + logWarning("Temporary checkpoint location created which is deleted normally when" + + s" the query didn't fail: $tempDir. If it's required to delete it under any" + + s" circumstances, please set ${SQLConf.FORCE_DELETE_TEMP_CHECKPOINT_LOCATION.key} to" + + s" true. 
Important to know deleting temp checkpoint folder is best effort.") + tempDir } else { throw new AnalysisException( "checkpointLocation must be specified either " + @@ -254,7 +258,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo } (sink, trigger) match { - case (v2Sink: StreamingWriteSupportProvider, trigger: ContinuousTrigger) => + case (table: SupportsWrite, trigger: ContinuousTrigger) => if (operationCheckEnabled) { UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) } @@ -263,7 +267,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo userSpecifiedName.orNull, checkpointLocation, analyzedPlan, - v2Sink, + table, trigger, triggerClock, outputMode, @@ -308,7 +312,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo userSpecifiedCheckpointLocation: Option[String], df: DataFrame, extraOptions: Map[String, String], - sink: BaseStreamingSink, + sink: Table, outputMode: OutputMode, useTempCheckpointLocation: Boolean = false, recoverFromCheckpointLocation: Boolean = true, diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index 2612b6185fd4c..255a9f887878b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -24,19 +24,19 @@ import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.sources.Filter; import org.apache.spark.sql.sources.GreaterThan; -import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.Table; import org.apache.spark.sql.sources.v2.TableProvider; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; public class JavaAdvancedDataSourceV2 implements TableProvider { @Override - public Table getTable(DataSourceOptions options) { + public Table getTable(CaseInsensitiveStringMap options) { return new JavaSimpleBatchTable() { @Override - public ScanBuilder newScanBuilder(DataSourceOptions options) { + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { return new AdvancedScanBuilder(); } }; diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java index d72ab5338aa8c..699859cfaebe1 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java @@ -21,11 +21,11 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; -import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.Table; import org.apache.spark.sql.sources.v2.TableProvider; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.sql.vectorized.ColumnarBatch; @@ -49,10 +49,10 @@ public PartitionReaderFactory createReaderFactory() { } @Override - public Table 
getTable(DataSourceOptions options) { + public Table getTable(CaseInsensitiveStringMap options) { return new JavaSimpleBatchTable() { @Override - public ScanBuilder newScanBuilder(DataSourceOptions options) { + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { return new MyScanBuilder(); } }; diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java index a513bfb26ef1c..391af5a306a16 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java @@ -20,15 +20,17 @@ import java.io.IOException; import java.util.Arrays; +import org.apache.spark.sql.catalog.v2.expressions.Expressions; +import org.apache.spark.sql.catalog.v2.expressions.Transform; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; -import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.Table; import org.apache.spark.sql.sources.v2.TableProvider; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.sources.v2.reader.partitioning.ClusteredDistribution; import org.apache.spark.sql.sources.v2.reader.partitioning.Distribution; import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; public class JavaPartitionAwareDataSource implements TableProvider { @@ -54,10 +56,15 @@ public Partitioning outputPartitioning() { } @Override - public Table getTable(DataSourceOptions options) { + public Table getTable(CaseInsensitiveStringMap options) { return new JavaSimpleBatchTable() { @Override - public ScanBuilder newScanBuilder(DataSourceOptions options) { + public Transform[] partitioning() { + return new Transform[] { Expressions.identity("i") }; + } + + @Override + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { return new MyScanBuilder(); } }; diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaReportStatisticsDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaReportStatisticsDataSource.java new file mode 100644 index 0000000000000..f3755e18b58d5 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaReportStatisticsDataSource.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.org.apache.spark.sql.sources.v2; + +import java.util.OptionalLong; + +import org.apache.spark.sql.sources.v2.Table; +import org.apache.spark.sql.sources.v2.TableProvider; +import org.apache.spark.sql.sources.v2.reader.InputPartition; +import org.apache.spark.sql.sources.v2.reader.ScanBuilder; +import org.apache.spark.sql.sources.v2.reader.Statistics; +import org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +public class JavaReportStatisticsDataSource implements TableProvider { + class MyScanBuilder extends JavaSimpleScanBuilder implements SupportsReportStatistics { + @Override + public Statistics estimateStatistics() { + return new Statistics() { + @Override + public OptionalLong sizeInBytes() { + return OptionalLong.of(80); + } + + @Override + public OptionalLong numRows() { + return OptionalLong.of(10); + } + }; + } + + @Override + public InputPartition[] planInputPartitions() { + InputPartition[] partitions = new InputPartition[2]; + partitions[0] = new JavaRangeInputPartition(0, 5); + partitions[1] = new JavaRangeInputPartition(5, 10); + return partitions; + } + } + + @Override + public Table getTable(CaseInsensitiveStringMap options) { + return new JavaSimpleBatchTable() { + @Override + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { + return new MyScanBuilder(); + } + }; + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java index 815d57ba94139..3800a94f88898 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java @@ -17,11 +17,11 @@ package test.org.apache.spark.sql.sources.v2; -import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.Table; import org.apache.spark.sql.sources.v2.TableProvider; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; public class JavaSchemaRequiredDataSource implements TableProvider { @@ -45,7 +45,7 @@ public InputPartition[] planInputPartitions() { } @Override - public Table getTable(DataSourceOptions options, StructType schema) { + public Table getTable(CaseInsensitiveStringMap options, StructType schema) { return new JavaSimpleBatchTable() { @Override @@ -54,14 +54,14 @@ public StructType schema() { } @Override - public ScanBuilder newScanBuilder(DataSourceOptions options) { + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { return new MyScanBuilder(schema); } }; } @Override - public Table getTable(DataSourceOptions options) { + public Table getTable(CaseInsensitiveStringMap options) { throw new IllegalArgumentException("requires a user-supplied schema"); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleBatchTable.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleBatchTable.java index cb5954d5a6211..9b0eb610a206f 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleBatchTable.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleBatchTable.java @@ -18,15 +18,23 @@ package test.org.apache.spark.sql.sources.v2; import java.io.IOException; +import 
java.util.Arrays; +import java.util.HashSet; +import java.util.Set; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; -import org.apache.spark.sql.sources.v2.SupportsBatchRead; +import org.apache.spark.sql.sources.v2.SupportsRead; import org.apache.spark.sql.sources.v2.Table; +import org.apache.spark.sql.sources.v2.TableCapability; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; -abstract class JavaSimpleBatchTable implements Table, SupportsBatchRead { +abstract class JavaSimpleBatchTable implements Table, SupportsRead { + private static final Set CAPABILITIES = new HashSet<>(Arrays.asList( + TableCapability.BATCH_READ, + TableCapability.BATCH_WRITE, + TableCapability.TRUNCATE)); @Override public StructType schema() { @@ -37,6 +45,11 @@ public StructType schema() { public String name() { return this.getClass().toString(); } + + @Override + public Set capabilities() { + return CAPABILITIES; + } } abstract class JavaSimpleScanBuilder implements ScanBuilder, Scan, Batch { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java index 852c4546df885..7474f36c97f75 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java @@ -17,10 +17,10 @@ package test.org.apache.spark.sql.sources.v2; -import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.Table; import org.apache.spark.sql.sources.v2.TableProvider; import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; public class JavaSimpleDataSourceV2 implements TableProvider { @@ -36,10 +36,10 @@ public InputPartition[] planInputPartitions() { } @Override - public Table getTable(DataSourceOptions options) { + public Table getTable(CaseInsensitiveStringMap options) { return new JavaSimpleBatchTable() { @Override - public ScanBuilder newScanBuilder(DataSourceOptions options) { + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { return new MyScanBuilder(); } }; diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index a36b0cfa6ff18..914af589384df 100644 --- a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -9,6 +9,6 @@ org.apache.spark.sql.streaming.sources.FakeReadMicroBatchOnly org.apache.spark.sql.streaming.sources.FakeReadContinuousOnly org.apache.spark.sql.streaming.sources.FakeReadBothModes org.apache.spark.sql.streaming.sources.FakeReadNeitherMode -org.apache.spark.sql.streaming.sources.FakeWriteSupportProvider +org.apache.spark.sql.streaming.sources.FakeWriteOnly org.apache.spark.sql.streaming.sources.FakeNoWrite org.apache.spark.sql.streaming.sources.FakeWriteSupportProviderV1Fallback diff --git a/sql/core/src/test/resources/sql-tests/inputs/describe-query.sql b/sql/core/src/test/resources/sql-tests/inputs/describe-query.sql new file mode 100644 index 0000000000000..bc144d01cee64 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/describe-query.sql 
@@ -0,0 +1,27 @@ +-- Test tables +CREATE table desc_temp1 (key int COMMENT 'column_comment', val string) USING PARQUET; +CREATE table desc_temp2 (key int, val string) USING PARQUET; + +-- Simple Describe query +DESC SELECT key, key + 1 as plusone FROM desc_temp1; +DESC QUERY SELECT * FROM desc_temp2; +DESC SELECT key, COUNT(*) as count FROM desc_temp1 group by key; +DESC SELECT 10.00D as col1; +DESC QUERY SELECT key FROM desc_temp1 UNION ALL select CAST(1 AS DOUBLE); +DESC QUERY VALUES(1.00D, 'hello') as tab1(col1, col2); +DESC QUERY FROM desc_temp1 a SELECT *; + + +-- Error cases. +DESC WITH s AS (SELECT 'hello' as col1) SELECT * FROM s; +DESCRIBE QUERY WITH s AS (SELECT * from desc_temp1) SELECT * FROM s; +DESCRIBE INSERT INTO desc_temp1 values (1, 'val1'); +DESCRIBE INSERT INTO desc_temp1 SELECT * FROM desc_temp2; +DESCRIBE + FROM desc_temp1 a + insert into desc_temp1 select * + insert into desc_temp2 select *; + +-- cleanup +DROP TABLE desc_temp1; +DROP TABLE desc_temp2; diff --git a/sql/core/src/test/resources/sql-tests/results/describe-query.sql.out b/sql/core/src/test/resources/sql-tests/results/describe-query.sql.out new file mode 100644 index 0000000000000..36cb314884779 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/describe-query.sql.out @@ -0,0 +1,171 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 16 + + +-- !query 0 +CREATE table desc_temp1 (key int COMMENT 'column_comment', val string) USING PARQUET +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE table desc_temp2 (key int, val string) USING PARQUET +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +DESC SELECT key, key + 1 as plusone FROM desc_temp1 +-- !query 2 schema +struct +-- !query 2 output +key int column_comment +plusone int + + +-- !query 3 +DESC QUERY SELECT * FROM desc_temp2 +-- !query 3 schema +struct +-- !query 3 output +key int +val string + + +-- !query 4 +DESC SELECT key, COUNT(*) as count FROM desc_temp1 group by key +-- !query 4 schema +struct +-- !query 4 output +key int column_comment +count bigint + + +-- !query 5 +DESC SELECT 10.00D as col1 +-- !query 5 schema +struct +-- !query 5 output +col1 double + + +-- !query 6 +DESC QUERY SELECT key FROM desc_temp1 UNION ALL select CAST(1 AS DOUBLE) +-- !query 6 schema +struct +-- !query 6 output +key double + + +-- !query 7 +DESC QUERY VALUES(1.00D, 'hello') as tab1(col1, col2) +-- !query 7 schema +struct +-- !query 7 output +col1 double +col2 string + + +-- !query 8 +DESC QUERY FROM desc_temp1 a SELECT * +-- !query 8 schema +struct +-- !query 8 output +key int column_comment +val string + + +-- !query 9 +DESC WITH s AS (SELECT 'hello' as col1) SELECT * FROM s +-- !query 9 schema +struct<> +-- !query 9 output +org.apache.spark.sql.catalyst.parser.ParseException + +mismatched input 'AS' expecting {, '.'}(line 1, pos 12) + +== SQL == +DESC WITH s AS (SELECT 'hello' as col1) SELECT * FROM s +------------^^^ + + +-- !query 10 +DESCRIBE QUERY WITH s AS (SELECT * from desc_temp1) SELECT * FROM s +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.catalyst.parser.ParseException + +mismatched input 's' expecting {, '.'}(line 1, pos 20) + +== SQL == +DESCRIBE QUERY WITH s AS (SELECT * from desc_temp1) SELECT * FROM s +--------------------^^^ + + +-- !query 11 +DESCRIBE INSERT INTO desc_temp1 values (1, 'val1') +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.catalyst.parser.ParseException + +mismatched input 'desc_temp1' expecting 
{, '.'}(line 1, pos 21) + +== SQL == +DESCRIBE INSERT INTO desc_temp1 values (1, 'val1') +---------------------^^^ + + +-- !query 12 +DESCRIBE INSERT INTO desc_temp1 SELECT * FROM desc_temp2 +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.catalyst.parser.ParseException + +mismatched input 'desc_temp1' expecting {, '.'}(line 1, pos 21) + +== SQL == +DESCRIBE INSERT INTO desc_temp1 SELECT * FROM desc_temp2 +---------------------^^^ + + +-- !query 13 +DESCRIBE + FROM desc_temp1 a + insert into desc_temp1 select * + insert into desc_temp2 select * +-- !query 13 schema +struct<> +-- !query 13 output +org.apache.spark.sql.catalyst.parser.ParseException + +mismatched input 'insert' expecting {, '(', ',', 'SELECT', 'WHERE', 'GROUP', 'ORDER', 'HAVING', 'LIMIT', 'JOIN', 'CROSS', 'INNER', 'LEFT', 'RIGHT', 'FULL', 'NATURAL', 'PIVOT', 'LATERAL', 'WINDOW', 'UNION', 'EXCEPT', 'MINUS', 'INTERSECT', 'SORT', 'CLUSTER', 'DISTRIBUTE', 'ANTI'}(line 3, pos 5) + +== SQL == +DESCRIBE + FROM desc_temp1 a + insert into desc_temp1 select * +-----^^^ + insert into desc_temp2 select * + + +-- !query 14 +DROP TABLE desc_temp1 +-- !query 14 schema +struct<> +-- !query 14 output + + + +-- !query 15 +DROP TABLE desc_temp2 +-- !query 15 schema +struct<> +-- !query 15 output + diff --git a/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/chk%252520%252525@%252523chk/commits/0 b/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/chk%252520%252525@%252523chk/commits/0 new file mode 100644 index 0000000000000..9c1e3021c3ead --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/chk%252520%252525@%252523chk/commits/0 @@ -0,0 +1,2 @@ +v1 +{"nextBatchWatermarkMs":0} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/chk%252520%252525@%252523chk/metadata b/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/chk%252520%252525@%252523chk/metadata new file mode 100644 index 0000000000000..3071b0dfc550b --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/chk%252520%252525@%252523chk/metadata @@ -0,0 +1 @@ +{"id":"09be7fb3-49d8-48a6-840d-e9c2ad92a898"} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/chk%252520%252525@%252523chk/offsets/0 b/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/chk%252520%252525@%252523chk/offsets/0 new file mode 100644 index 0000000000000..a0a567631fd14 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/chk%252520%252525@%252523chk/offsets/0 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1549649384149,"conf":{"spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider","spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion":"2","spark.sql.streaming.multipleWatermarkPolicy":"min","spark.sql.streaming.aggregation.stateFormatVersion":"2","spark.sql.shuffle.partitions":"200"}} +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/output %@#output/part-00000-97f675a2-bb82-4201-8245-05f3dae4c372-c000.snappy.parquet b/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/output %@#output/part-00000-97f675a2-bb82-4201-8245-05f3dae4c372-c000.snappy.parquet new file mode 100644 index 0000000000000..1b2919b25c381 Binary files /dev/null and 
b/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/output %@#output/part-00000-97f675a2-bb82-4201-8245-05f3dae4c372-c000.snappy.parquet differ diff --git a/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/output%20%25@%23output/_spark_metadata/0 b/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/output%20%25@%23output/_spark_metadata/0 new file mode 100644 index 0000000000000..79768f89d6eca --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/output%20%25@%23output/_spark_metadata/0 @@ -0,0 +1,2 @@ +v1 +{"path":"file://TEMPDIR/output%20%25@%23output/part-00000-97f675a2-bb82-4201-8245-05f3dae4c372-c000.snappy.parquet","size":404,"isDir":false,"modificationTime":1549649385000,"blockReplication":1,"blockSize":33554432,"action":"add"} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 54342b691109d..e46802f69ed67 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -334,83 +334,97 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo test("SPARK-24204 error handling for unsupported Interval data types - csv, json, parquet, orc") { withTempDir { dir => val tempDir = new File(dir, "files").getCanonicalPath - // TODO(SPARK-26744): support data type validating in V2 data source, and test V2 as well. - withSQLConf(SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "orc") { - // write path - Seq("csv", "json", "parquet", "orc").foreach { format => - var msg = intercept[AnalysisException] { - sql("select interval 1 days").write.format(format).mode("overwrite").save(tempDir) - }.getMessage - assert(msg.contains("Cannot save interval data type into external storage.")) - - msg = intercept[AnalysisException] { - spark.udf.register("testType", () => new IntervalData()) - sql("select testType()").write.format(format).mode("overwrite").save(tempDir) - }.getMessage - assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support calendarinterval data type.")) + Seq(true).foreach { useV1 => + val useV1List = if (useV1) { + "orc" + } else { + "" } + def errorMessage(format: String, isWrite: Boolean): String = { + if (isWrite && (useV1 || format != "orc")) { + "cannot save interval data type into external storage." + } else { + s"$format data source does not support calendarinterval data type." 
+ } + } + + withSQLConf(SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> useV1List) { + // write path + Seq("csv", "json", "parquet", "orc").foreach { format => + var msg = intercept[AnalysisException] { + sql("select interval 1 days").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT).contains(errorMessage(format, true))) + } - // read path - Seq("parquet", "csv").foreach { format => - var msg = intercept[AnalysisException] { - val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil) - spark.range(1).write.format(format).mode("overwrite").save(tempDir) - spark.read.schema(schema).format(format).load(tempDir).collect() - }.getMessage - assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support calendarinterval data type.")) - - msg = intercept[AnalysisException] { - val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil) - spark.range(1).write.format(format).mode("overwrite").save(tempDir) - spark.read.schema(schema).format(format).load(tempDir).collect() - }.getMessage - assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support calendarinterval data type.")) + // read path + Seq("parquet", "csv").foreach { format => + var msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT).contains(errorMessage(format, false))) + + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT).contains(errorMessage(format, false))) + } } } } } test("SPARK-24204 error handling for unsupported Null data types - csv, parquet, orc") { - // TODO(SPARK-26744): support data type validating in V2 data source, and test V2 as well. 
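
The rewritten interval test above folds the expected error text into an errorMessage helper so the same loop body can cover both write and read paths. A small sketch of the behaviour being asserted, assuming a local SparkSession and a scratch path (both hypothetical); the message fragment is the one the test checks on the V1 write path.

import scala.util.Try

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("interval-write-sketch").getOrCreate()

// Interval values cannot be persisted by the built-in file sources.
val failure = Try {
  spark.sql("select interval 1 days")
    .write.format("parquet").mode("overwrite").save("/tmp/interval-sketch")
}
assert(failure.isFailure, "expected the interval write to be rejected")
failure.failed.foreach { t =>
  assert(t.getMessage.contains("Cannot save interval data type into external storage."))
}

spark.stop()
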
- withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> "orc", - SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "orc") { - withTempDir { dir => - val tempDir = new File(dir, "files").getCanonicalPath - - Seq("parquet", "csv", "orc").foreach { format => - // write path - var msg = intercept[AnalysisException] { - sql("select null").write.format(format).mode("overwrite").save(tempDir) - }.getMessage - assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support null data type.")) - - msg = intercept[AnalysisException] { - spark.udf.register("testType", () => new NullData()) - sql("select testType()").write.format(format).mode("overwrite").save(tempDir) - }.getMessage - assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support null data type.")) - - // read path - msg = intercept[AnalysisException] { - val schema = StructType(StructField("a", NullType, true) :: Nil) - spark.range(1).write.format(format).mode("overwrite").save(tempDir) - spark.read.schema(schema).format(format).load(tempDir).collect() - }.getMessage - assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support null data type.")) - - msg = intercept[AnalysisException] { - val schema = StructType(StructField("a", new NullUDT(), true) :: Nil) - spark.range(1).write.format(format).mode("overwrite").save(tempDir) - spark.read.schema(schema).format(format).load(tempDir).collect() - }.getMessage - assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support null data type.")) + Seq(true).foreach { useV1 => + val useV1List = if (useV1) { + "orc" + } else { + "" + } + def errorMessage(format: String): String = { + s"$format data source does not support null data type." + } + withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> useV1List, + SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> useV1List) { + withTempDir { dir => + val tempDir = new File(dir, "files").getCanonicalPath + + Seq("parquet", "csv", "orc").foreach { format => + // write path + var msg = intercept[AnalysisException] { + sql("select null").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(errorMessage(format))) + + msg = intercept[AnalysisException] { + spark.udf.register("testType", () => new NullData()) + sql("select testType()").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(errorMessage(format))) + + // read path + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", NullType, true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(errorMessage(format))) + + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", new NullUDT(), true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(errorMessage(format))) + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 24b312348bd67..62f3f98bf28ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -27,7 +27,7 @@ 
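
Both refactored tests above are shaped so that, once the validation tracked by SPARK-26744 lands, the same body can run with a format on either the V1 or the V2 code path by toggling the source lists. A sketch of that toggle outside the test harness, assuming an active SparkSession named spark and using the SQLConf entries referenced in this diff; the body is left as a placeholder.

import org.apache.spark.sql.internal.SQLConf

// "orc" keeps ORC on the V1 path; "" routes it through the V2 implementation.
Seq("orc", "").foreach { useV1List =>
  spark.conf.set(SQLConf.USE_V1_SOURCE_WRITER_LIST.key, useV1List)
  spark.conf.set(SQLConf.USE_V1_SOURCE_READER_LIST.key, useV1List)
  try {
    // ... shared assertion body: write/read each format and check the expected error ...
  } finally {
    spark.conf.unset(SQLConf.USE_V1_SOURCE_WRITER_LIST.key)
    spark.conf.unset(SQLConf.USE_V1_SOURCE_READER_LIST.key)
  }
}
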
import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util.{fileToString, stringToFile} import org.apache.spark.sql.execution.HiveResult.hiveResultString -import org.apache.spark.sql.execution.command.{DescribeColumnCommand, DescribeTableCommand} +import org.apache.spark.sql.execution.command.{DescribeColumnCommand, DescribeCommandBase} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.StructType @@ -277,7 +277,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { // Returns true if the plan is supposed to be sorted. def isSorted(plan: LogicalPlan): Boolean = plan match { case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false - case _: DescribeTableCommand | _: DescribeColumnCommand => true + case _: DescribeCommandBase | _: DescribeColumnCommand => true case PhysicalOperation(_, _, Sort(_, true, _)) => true case _ => plan.children.iterator.exists(isSorted) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala index 9f33feb1950c7..881268440ccd7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -234,6 +234,9 @@ case class MyParser(spark: SparkSession, delegate: ParserInterface) extends Pars override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = delegate.parseFunctionIdentifier(sqlText) + override def parseMultipartIdentifier(sqlText: String): Seq[String] = + delegate.parseMultipartIdentifier(sqlText) + override def parseTableSchema(sqlText: String): StructType = delegate.parseTableSchema(sqlText) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index 31b9bcdafbab8..be3d0794d4036 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -215,19 +215,6 @@ class SparkSqlParserSuite extends AnalysisTest { "no viable alternative at input") } - test("create table using - schema") { - assertEqual("CREATE TABLE my_tab(a INT COMMENT 'test', b STRING) USING parquet", - createTableUsing( - table = "my_tab", - schema = (new StructType) - .add("a", IntegerType, nullable = true, "test") - .add("b", StringType) - ) - ) - intercept("CREATE TABLE my_tab(a: INT COMMENT 'test', b: STRING) USING parquet", - "no viable alternative at input") - } - test("create view as insert into table") { // Single insert query intercept("CREATE VIEW testView AS INSERT INTO jt VALUES(1, 1)", @@ -240,15 +227,20 @@ class SparkSqlParserSuite extends AnalysisTest { } test("SPARK-17328 Fix NPE with EXPLAIN DESCRIBE TABLE") { + assertEqual("describe t", + DescribeTableCommand(TableIdentifier("t"), Map.empty, isExtended = false)) assertEqual("describe table t", - DescribeTableCommand( - TableIdentifier("t"), Map.empty, isExtended = false)) + DescribeTableCommand(TableIdentifier("t"), Map.empty, isExtended = false)) assertEqual("describe table extended t", - DescribeTableCommand( - TableIdentifier("t"), Map.empty, isExtended = true)) + DescribeTableCommand(TableIdentifier("t"), Map.empty, isExtended = true)) 
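
MyParser gains a parseMultipartIdentifier override because ParserInterface now exposes multi-part (catalog.namespace.table) name parsing, so any delegating parser has to forward it. A small sketch of what the new method returns, using the Catalyst parser that backs these suites; the literal names are arbitrary examples.

import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

// A dotted name is split into its catalog/namespace/table parts.
val parts: Seq[String] = CatalystSqlParser.parseMultipartIdentifier("testcat.db.tab")
assert(parts == Seq("testcat", "db", "tab"))

// Backtick quoting keeps a dot inside a single part.
assert(CatalystSqlParser.parseMultipartIdentifier("`a.b`.c") == Seq("a.b", "c"))
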
assertEqual("describe table formatted t", - DescribeTableCommand( - TableIdentifier("t"), Map.empty, isExtended = true)) + DescribeTableCommand(TableIdentifier("t"), Map.empty, isExtended = true)) + } + + test("describe query") { + val query = "SELECT * FROM t" + assertEqual("DESCRIBE QUERY " + query, DescribeQueryCommand(parser.parsePlan(query))) + assertEqual("DESCRIBE " + query, DescribeQueryCommand(parser.parsePlan(query))) } test("describe table column") { @@ -387,4 +379,12 @@ class SparkSqlParserSuite extends AnalysisTest { "INSERT INTO tbl2 SELECT * WHERE jt.id > 4", "Operation not allowed: ALTER VIEW ... AS FROM ... [INSERT INTO ...]+") } + + test("database and schema tokens are interchangeable") { + assertEqual("CREATE DATABASE foo", parser.parsePlan("CREATE SCHEMA foo")) + assertEqual("DROP DATABASE foo", parser.parsePlan("DROP SCHEMA foo")) + assertEqual("ALTER DATABASE foo SET DBPROPERTIES ('x' = 'y')", + parser.parsePlan("ALTER SCHEMA foo SET DBPROPERTIES ('x' = 'y')")) + assertEqual("DESC DATABASE foo", parser.parsePlan("DESC SCHEMA foo")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index c36872a6a5289..86874b9817c20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{BinaryType, Decimal, IntegerType, StructField, StructType} +import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.util.Utils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala index e0ccae15f1d05..0dd11c1e518e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala @@ -32,13 +32,12 @@ import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan import org.apache.spark.sql.catalyst.expressions.JsonTuple import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{Generate, InsertIntoDir, LogicalPlan} -import org.apache.spark.sql.catalyst.plans.logical.{Project, ScriptTransformation} +import org.apache.spark.sql.catalyst.plans.logical.{Generate, InsertIntoDir, LogicalPlan, Project, ScriptTransformation} import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class DDLParserSuite extends PlanTest with SharedSQLContext { @@ -415,173 +414,28 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { assert(ct.tableDesc.storage.locationUri == Some(new URI("/something/anything"))) } - test("create table - with partitioned by") { - val query = "CREATE TABLE my_tab(a INT comment 'test', b STRING) " + - "USING parquet PARTITIONED 
BY (a)" - - val expectedTableDesc = CatalogTable( - identifier = TableIdentifier("my_tab"), - tableType = CatalogTableType.MANAGED, - storage = CatalogStorageFormat.empty, - schema = new StructType() - .add("a", IntegerType, nullable = true, "test") - .add("b", StringType), - provider = Some("parquet"), - partitionColumnNames = Seq("a") - ) - - parser.parsePlan(query) match { - case CreateTable(tableDesc, _, None) => - assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) - case other => - fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + - s"got ${other.getClass.getName}: $query") + test("Duplicate clauses - create hive table") { + def createTableHeader(duplicateClause: String): String = { + s"CREATE TABLE my_tab(a INT, b STRING) STORED AS parquet $duplicateClause $duplicateClause" } - } - - test("create table - with bucket") { - val query = "CREATE TABLE my_tab(a INT, b STRING) USING parquet " + - "CLUSTERED BY (a) SORTED BY (b) INTO 5 BUCKETS" - - val expectedTableDesc = CatalogTable( - identifier = TableIdentifier("my_tab"), - tableType = CatalogTableType.MANAGED, - storage = CatalogStorageFormat.empty, - schema = new StructType().add("a", IntegerType).add("b", StringType), - provider = Some("parquet"), - bucketSpec = Some(BucketSpec(5, Seq("a"), Seq("b"))) - ) - - parser.parsePlan(query) match { - case CreateTable(tableDesc, _, None) => - assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) - case other => - fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + - s"got ${other.getClass.getName}: $query") - } - } - - test("create table - with comment") { - val sql = "CREATE TABLE my_tab(a INT, b STRING) USING parquet COMMENT 'abc'" - val expectedTableDesc = CatalogTable( - identifier = TableIdentifier("my_tab"), - tableType = CatalogTableType.MANAGED, - storage = CatalogStorageFormat.empty, - schema = new StructType().add("a", IntegerType).add("b", StringType), - provider = Some("parquet"), - comment = Some("abc")) - - parser.parsePlan(sql) match { - case CreateTable(tableDesc, _, None) => - assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) - case other => - fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + - s"got ${other.getClass.getName}: $sql") - } - } - - test("create table - with table properties") { - val sql = "CREATE TABLE my_tab(a INT, b STRING) USING parquet TBLPROPERTIES('test' = 'test')" - - val expectedTableDesc = CatalogTable( - identifier = TableIdentifier("my_tab"), - tableType = CatalogTableType.MANAGED, - storage = CatalogStorageFormat.empty, - schema = new StructType().add("a", IntegerType).add("b", StringType), - provider = Some("parquet"), - properties = Map("test" -> "test")) - - parser.parsePlan(sql) match { - case CreateTable(tableDesc, _, None) => - assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) - case other => - fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + - s"got ${other.getClass.getName}: $sql") - } - } - - test("Duplicate clauses - create table") { - def createTableHeader(duplicateClause: String, isNative: Boolean): String = { - val fileFormat = if (isNative) "USING parquet" else "STORED AS parquet" - s"CREATE TABLE my_tab(a INT, b STRING) $fileFormat $duplicateClause $duplicateClause" - } - - Seq(true, false).foreach { isNative => - intercept(createTableHeader("TBLPROPERTIES('test' = 'test2')", 
isNative), - "Found duplicate clauses: TBLPROPERTIES") - intercept(createTableHeader("LOCATION '/tmp/file'", isNative), - "Found duplicate clauses: LOCATION") - intercept(createTableHeader("COMMENT 'a table'", isNative), - "Found duplicate clauses: COMMENT") - intercept(createTableHeader("CLUSTERED BY(b) INTO 256 BUCKETS", isNative), - "Found duplicate clauses: CLUSTERED BY") - } - - // Only for native data source tables - intercept(createTableHeader("PARTITIONED BY (b)", isNative = true), - "Found duplicate clauses: PARTITIONED BY") - - // Only for Hive serde tables - intercept(createTableHeader("PARTITIONED BY (k int)", isNative = false), + intercept(createTableHeader("TBLPROPERTIES('test' = 'test2')"), + "Found duplicate clauses: TBLPROPERTIES") + intercept(createTableHeader("LOCATION '/tmp/file'"), + "Found duplicate clauses: LOCATION") + intercept(createTableHeader("COMMENT 'a table'"), + "Found duplicate clauses: COMMENT") + intercept(createTableHeader("CLUSTERED BY(b) INTO 256 BUCKETS"), + "Found duplicate clauses: CLUSTERED BY") + intercept(createTableHeader("PARTITIONED BY (k int)"), "Found duplicate clauses: PARTITIONED BY") - intercept(createTableHeader("STORED AS parquet", isNative = false), + intercept(createTableHeader("STORED AS parquet"), "Found duplicate clauses: STORED AS/BY") intercept( - createTableHeader("ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe'", isNative = false), + createTableHeader("ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe'"), "Found duplicate clauses: ROW FORMAT") } - test("create table - with location") { - val v1 = "CREATE TABLE my_tab(a INT, b STRING) USING parquet LOCATION '/tmp/file'" - - val expectedTableDesc = CatalogTable( - identifier = TableIdentifier("my_tab"), - tableType = CatalogTableType.EXTERNAL, - storage = CatalogStorageFormat.empty.copy(locationUri = Some(new URI("/tmp/file"))), - schema = new StructType().add("a", IntegerType).add("b", StringType), - provider = Some("parquet")) - - parser.parsePlan(v1) match { - case CreateTable(tableDesc, _, None) => - assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) - case other => - fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + - s"got ${other.getClass.getName}: $v1") - } - - val v2 = - """ - |CREATE TABLE my_tab(a INT, b STRING) - |USING parquet - |OPTIONS (path '/tmp/file') - |LOCATION '/tmp/file' - """.stripMargin - val e = intercept[ParseException] { - parser.parsePlan(v2) - } - assert(e.message.contains("you can only specify one of them.")) - } - - test("create table - byte length literal table name") { - val sql = "CREATE TABLE 1m.2g(a INT) USING parquet" - - val expectedTableDesc = CatalogTable( - identifier = TableIdentifier("2g", Some("1m")), - tableType = CatalogTableType.MANAGED, - storage = CatalogStorageFormat.empty, - schema = new StructType().add("a", IntegerType), - provider = Some("parquet")) - - parser.parsePlan(sql) match { - case CreateTable(tableDesc, _, None) => - assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) - case other => - fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + - s"got ${other.getClass.getName}: $sql") - } - } - test("insert overwrite directory") { val v1 = "INSERT OVERWRITE DIRECTORY '/tmp/file' USING parquet SELECT 1 as a" parser.parsePlan(v1) match { @@ -1032,64 +886,6 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { assert(e.contains("Found an empty partition key 'b'")) } - test("drop 
table") { - val tableName1 = "db.tab" - val tableName2 = "tab" - - val parsed = Seq( - s"DROP TABLE $tableName1", - s"DROP TABLE IF EXISTS $tableName1", - s"DROP TABLE $tableName2", - s"DROP TABLE IF EXISTS $tableName2", - s"DROP TABLE $tableName2 PURGE", - s"DROP TABLE IF EXISTS $tableName2 PURGE" - ).map(parser.parsePlan) - - val expected = Seq( - DropTableCommand(TableIdentifier("tab", Option("db")), ifExists = false, isView = false, - purge = false), - DropTableCommand(TableIdentifier("tab", Option("db")), ifExists = true, isView = false, - purge = false), - DropTableCommand(TableIdentifier("tab", None), ifExists = false, isView = false, - purge = false), - DropTableCommand(TableIdentifier("tab", None), ifExists = true, isView = false, - purge = false), - DropTableCommand(TableIdentifier("tab", None), ifExists = false, isView = false, - purge = true), - DropTableCommand(TableIdentifier("tab", None), ifExists = true, isView = false, - purge = true)) - - parsed.zip(expected).foreach { case (p, e) => comparePlans(p, e) } - } - - test("drop view") { - val viewName1 = "db.view" - val viewName2 = "view" - - val parsed1 = parser.parsePlan(s"DROP VIEW $viewName1") - val parsed2 = parser.parsePlan(s"DROP VIEW IF EXISTS $viewName1") - val parsed3 = parser.parsePlan(s"DROP VIEW $viewName2") - val parsed4 = parser.parsePlan(s"DROP VIEW IF EXISTS $viewName2") - - val expected1 = - DropTableCommand(TableIdentifier("view", Option("db")), ifExists = false, isView = true, - purge = false) - val expected2 = - DropTableCommand(TableIdentifier("view", Option("db")), ifExists = true, isView = true, - purge = false) - val expected3 = - DropTableCommand(TableIdentifier("view", None), ifExists = false, isView = true, - purge = false) - val expected4 = - DropTableCommand(TableIdentifier("view", None), ifExists = true, isView = true, - purge = false) - - comparePlans(parsed1, expected1) - comparePlans(parsed2, expected2) - comparePlans(parsed3, expected3) - comparePlans(parsed4, expected4) - } - test("show columns") { val sql1 = "SHOW COLUMNS FROM t1" val sql2 = "SHOW COLUMNS IN db1.t1" @@ -1165,84 +961,6 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { comparePlans(parsed, expected) } - test("support for other types in OPTIONS") { - val sql = - """ - |CREATE TABLE table_name USING json - |OPTIONS (a 1, b 0.1, c TRUE) - """.stripMargin - - val expectedTableDesc = CatalogTable( - identifier = TableIdentifier("table_name"), - tableType = CatalogTableType.MANAGED, - storage = CatalogStorageFormat.empty.copy( - properties = Map("a" -> "1", "b" -> "0.1", "c" -> "true") - ), - schema = new StructType, - provider = Some("json") - ) - - parser.parsePlan(sql) match { - case CreateTable(tableDesc, _, None) => - assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) - case other => - fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + - s"got ${other.getClass.getName}: $sql") - } - } - - test("Test CTAS against data source tables") { - val s1 = - """ - |CREATE TABLE IF NOT EXISTS mydb.page_view - |USING parquet - |COMMENT 'This is the staging page view table' - |LOCATION '/user/external/page_view' - |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |AS SELECT * FROM src - """.stripMargin - - val s2 = - """ - |CREATE TABLE IF NOT EXISTS mydb.page_view - |USING parquet - |LOCATION '/user/external/page_view' - |COMMENT 'This is the staging page view table' - |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |AS SELECT * FROM src - """.stripMargin - - val s3 = - """ - 
|CREATE TABLE IF NOT EXISTS mydb.page_view - |USING parquet - |COMMENT 'This is the staging page view table' - |LOCATION '/user/external/page_view' - |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |AS SELECT * FROM src - """.stripMargin - - checkParsing(s1) - checkParsing(s2) - checkParsing(s3) - - def checkParsing(sql: String): Unit = { - val (desc, exists) = extractTableDesc(sql) - assert(exists) - assert(desc.identifier.database == Some("mydb")) - assert(desc.identifier.table == "page_view") - assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) - assert(desc.schema.isEmpty) // will be populated later when the table is actually created - assert(desc.comment == Some("This is the staging page view table")) - assert(desc.viewText.isEmpty) - assert(desc.viewDefaultDatabase.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.partitionColumnNames.isEmpty) - assert(desc.provider == Some("parquet")) - assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) - } - } - test("Test CTAS #1") { val s1 = """ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala new file mode 100644 index 0000000000000..06f7332086372 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -0,0 +1,504 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
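
The CREATE TABLE and CTAS parser tests removed from DDLParserSuite above reappear in the PlanResolutionSuite introduced here, where plans are parsed and then resolved against a test catalog instead of being checked at parse time only. A sketch of that parse-then-resolve pattern, assuming the DataSourceResolution rule, TestTableCatalog and CatalogNotFoundException used by the suite; the "testcat" name is the suite's own test fixture.

import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, TestTableCatalog}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.datasources.DataSourceResolution
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.util.CaseInsensitiveStringMap

// A single in-memory catalog registered under the name "testcat".
val testCat = new TestTableCatalog
testCat.initialize("testcat", CaseInsensitiveStringMap.empty())

val lookupCatalog: String => CatalogPlugin = {
  case "testcat" => testCat
  case name => throw new CatalogNotFoundException(s"No such catalog: $name")
}

val conf = new SQLConf
conf.setConfString("spark.sql.default.catalog", "testcat")

// Parse the SQL text, then let DataSourceResolution pick a v1 command or a v2 plan.
def parseAndResolve(sql: String): LogicalPlan =
  DataSourceResolution(conf, lookupCatalog).apply(CatalystSqlParser.parsePlan(sql))
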
+ */ + +package org.apache.spark.sql.execution.command + +import java.net.URI +import java.util.Locale + +import org.apache.spark.sql.{AnalysisException, SaveMode} +import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, Identifier, TableCatalog, TestTableCatalog} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.AnalysisTest +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, CreateV2Table, DropTable, LogicalPlan} +import org.apache.spark.sql.execution.datasources.{CreateTable, DataSourceResolution} +import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2 +import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType, StructType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class PlanResolutionSuite extends AnalysisTest { + import CatalystSqlParser._ + + private val orc2 = classOf[OrcDataSourceV2].getName + + private val testCat: TableCatalog = { + val newCatalog = new TestTableCatalog + newCatalog.initialize("testcat", CaseInsensitiveStringMap.empty()) + newCatalog + } + + private val lookupCatalog: String => CatalogPlugin = { + case "testcat" => + testCat + case name => + throw new CatalogNotFoundException(s"No such catalog: $name") + } + + def parseAndResolve(query: String): LogicalPlan = { + val newConf = conf.copy() + newConf.setConfString("spark.sql.default.catalog", "testcat") + DataSourceResolution(newConf, lookupCatalog).apply(parsePlan(query)) + } + + private def parseResolveCompare(query: String, expected: LogicalPlan): Unit = + comparePlans(parseAndResolve(query), expected, checkAnalysis = true) + + private def extractTableDesc(sql: String): (CatalogTable, Boolean) = { + parseAndResolve(sql).collect { + case CreateTable(tableDesc, mode, _) => (tableDesc, mode == SaveMode.Ignore) + }.head + } + + test("create table - with partitioned by") { + val query = "CREATE TABLE my_tab(a INT comment 'test', b STRING) " + + "USING parquet PARTITIONED BY (a)" + + val expectedTableDesc = CatalogTable( + identifier = TableIdentifier("my_tab"), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat.empty, + schema = new StructType() + .add("a", IntegerType, nullable = true, "test") + .add("b", StringType), + provider = Some("parquet"), + partitionColumnNames = Seq("a") + ) + + parseAndResolve(query) match { + case CreateTable(tableDesc, _, None) => + assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) + case other => + fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + + s"got ${other.getClass.getName}: $query") + } + } + + test("create table - partitioned by transforms") { + val transforms = Seq( + "bucket(16, b)", "years(ts)", "months(ts)", "days(ts)", "hours(ts)", "foo(a, 'bar', 34)", + "bucket(32, b), days(ts)") + transforms.foreach { transform => + val query = + s""" + |CREATE TABLE my_tab(a INT, b STRING) USING parquet + |PARTITIONED BY ($transform) + """.stripMargin + + val ae = intercept[AnalysisException] { + parseAndResolve(query) + } + + assert(ae.message + .contains(s"Transforms cannot be converted to partition columns: $transform")) + } + } + + test("create table - with bucket") { + val query = "CREATE TABLE my_tab(a INT, b STRING) USING parquet " + + "CLUSTERED BY (a) SORTED BY 
(b) INTO 5 BUCKETS" + + val expectedTableDesc = CatalogTable( + identifier = TableIdentifier("my_tab"), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat.empty, + schema = new StructType().add("a", IntegerType).add("b", StringType), + provider = Some("parquet"), + bucketSpec = Some(BucketSpec(5, Seq("a"), Seq("b"))) + ) + + parseAndResolve(query) match { + case CreateTable(tableDesc, _, None) => + assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) + case other => + fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + + s"got ${other.getClass.getName}: $query") + } + } + + test("create table - with comment") { + val sql = "CREATE TABLE my_tab(a INT, b STRING) USING parquet COMMENT 'abc'" + + val expectedTableDesc = CatalogTable( + identifier = TableIdentifier("my_tab"), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat.empty, + schema = new StructType().add("a", IntegerType).add("b", StringType), + provider = Some("parquet"), + comment = Some("abc")) + + parseAndResolve(sql) match { + case CreateTable(tableDesc, _, None) => + assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) + case other => + fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + + test("create table - with table properties") { + val sql = "CREATE TABLE my_tab(a INT, b STRING) USING parquet TBLPROPERTIES('test' = 'test')" + + val expectedTableDesc = CatalogTable( + identifier = TableIdentifier("my_tab"), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat.empty, + schema = new StructType().add("a", IntegerType).add("b", StringType), + provider = Some("parquet"), + properties = Map("test" -> "test")) + + parseAndResolve(sql) match { + case CreateTable(tableDesc, _, None) => + assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) + case other => + fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + + test("create table - with location") { + val v1 = "CREATE TABLE my_tab(a INT, b STRING) USING parquet LOCATION '/tmp/file'" + + val expectedTableDesc = CatalogTable( + identifier = TableIdentifier("my_tab"), + tableType = CatalogTableType.EXTERNAL, + storage = CatalogStorageFormat.empty.copy(locationUri = Some(new URI("/tmp/file"))), + schema = new StructType().add("a", IntegerType).add("b", StringType), + provider = Some("parquet")) + + parseAndResolve(v1) match { + case CreateTable(tableDesc, _, None) => + assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) + case other => + fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + + s"got ${other.getClass.getName}: $v1") + } + + val v2 = + """ + |CREATE TABLE my_tab(a INT, b STRING) + |USING parquet + |OPTIONS (path '/tmp/file') + |LOCATION '/tmp/file' + """.stripMargin + val e = intercept[AnalysisException] { + parseAndResolve(v2) + } + assert(e.message.contains("you can only specify one of them.")) + } + + test("create table - byte length literal table name") { + val sql = "CREATE TABLE 1m.2g(a INT) USING parquet" + + val expectedTableDesc = CatalogTable( + identifier = TableIdentifier("2g", Some("1m")), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat.empty, + schema = new StructType().add("a", IntegerType), + provider = Some("parquet")) + + 
parseAndResolve(sql) match { + case CreateTable(tableDesc, _, None) => + assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) + case other => + fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + + test("support for other types in OPTIONS") { + val sql = + """ + |CREATE TABLE table_name USING json + |OPTIONS (a 1, b 0.1, c TRUE) + """.stripMargin + + val expectedTableDesc = CatalogTable( + identifier = TableIdentifier("table_name"), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat.empty.copy( + properties = Map("a" -> "1", "b" -> "0.1", "c" -> "true") + ), + schema = new StructType, + provider = Some("json") + ) + + parseAndResolve(sql) match { + case CreateTable(tableDesc, _, None) => + assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) + case other => + fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + + test("Test CTAS against data source tables") { + val s1 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |COMMENT 'This is the staging page view table' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + val s2 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |LOCATION '/user/external/page_view' + |COMMENT 'This is the staging page view table' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + val s3 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |COMMENT 'This is the staging page view table' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + checkParsing(s1) + checkParsing(s2) + checkParsing(s3) + + def checkParsing(sql: String): Unit = { + val (desc, exists) = extractTableDesc(sql) + assert(exists) + assert(desc.identifier.database.contains("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.storage.locationUri.contains(new URI("/user/external/page_view"))) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created + assert(desc.comment.contains("This is the staging page view table")) + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.provider.contains("parquet")) + assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + } + } + + test("Test v2 CreateTable with known catalog in identifier") { + val sql = + s""" + |CREATE TABLE IF NOT EXISTS testcat.mydb.table_name ( + | id bigint, + | description string, + | point struct) + |USING parquet + |COMMENT 'table comment' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |OPTIONS (path 's3://bucket/path/to/data', other 20) + """.stripMargin + + val expectedProperties = Map( + "p1" -> "v1", + "p2" -> "v2", + "other" -> "20", + "provider" -> "parquet", + "location" -> "s3://bucket/path/to/data", + "comment" -> "table comment") + + parseAndResolve(sql) match { + case create: CreateV2Table => + assert(create.catalog.name == "testcat") + assert(create.tableName == Identifier.of(Array("mydb"), "table_name")) + assert(create.tableSchema == new StructType() + .add("id", LongType) + .add("description", StringType) + .add("point", new StructType().add("x", DoubleType).add("y", 
DoubleType))) + assert(create.partitioning.isEmpty) + assert(create.properties == expectedProperties) + assert(create.ignoreIfExists) + + case other => + fail(s"Expected to parse ${classOf[CreateV2Table].getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + + test("Test v2 CreateTable with data source v2 provider") { + val sql = + s""" + |CREATE TABLE IF NOT EXISTS mydb.page_view ( + | id bigint, + | description string, + | point struct) + |USING $orc2 + |COMMENT 'This is the staging page view table' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + """.stripMargin + + val expectedProperties = Map( + "p1" -> "v1", + "p2" -> "v2", + "provider" -> orc2, + "location" -> "/user/external/page_view", + "comment" -> "This is the staging page view table") + + parseAndResolve(sql) match { + case create: CreateV2Table => + assert(create.catalog.name == "testcat") + assert(create.tableName == Identifier.of(Array("mydb"), "page_view")) + assert(create.tableSchema == new StructType() + .add("id", LongType) + .add("description", StringType) + .add("point", new StructType().add("x", DoubleType).add("y", DoubleType))) + assert(create.partitioning.isEmpty) + assert(create.properties == expectedProperties) + assert(create.ignoreIfExists) + + case other => + fail(s"Expected to parse ${classOf[CreateV2Table].getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + + test("Test v2 CTAS with known catalog in identifier") { + val sql = + s""" + |CREATE TABLE IF NOT EXISTS testcat.mydb.table_name + |USING parquet + |COMMENT 'table comment' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |OPTIONS (path 's3://bucket/path/to/data', other 20) + |AS SELECT * FROM src + """.stripMargin + + val expectedProperties = Map( + "p1" -> "v1", + "p2" -> "v2", + "other" -> "20", + "provider" -> "parquet", + "location" -> "s3://bucket/path/to/data", + "comment" -> "table comment") + + parseAndResolve(sql) match { + case ctas: CreateTableAsSelect => + assert(ctas.catalog.name == "testcat") + assert(ctas.tableName == Identifier.of(Array("mydb"), "table_name")) + assert(ctas.properties == expectedProperties) + assert(ctas.writeOptions == Map("other" -> "20")) + assert(ctas.partitioning.isEmpty) + assert(ctas.ignoreIfExists) + + case other => + fail(s"Expected to parse ${classOf[CreateTableAsSelect].getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + + test("Test v2 CTAS with data source v2 provider") { + val sql = + s""" + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING $orc2 + |COMMENT 'This is the staging page view table' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + val expectedProperties = Map( + "p1" -> "v1", + "p2" -> "v2", + "provider" -> orc2, + "location" -> "/user/external/page_view", + "comment" -> "This is the staging page view table") + + parseAndResolve(sql) match { + case ctas: CreateTableAsSelect => + assert(ctas.catalog.name == "testcat") + assert(ctas.tableName == Identifier.of(Array("mydb"), "page_view")) + assert(ctas.properties == expectedProperties) + assert(ctas.writeOptions.isEmpty) + assert(ctas.partitioning.isEmpty) + assert(ctas.ignoreIfExists) + + case other => + fail(s"Expected to parse ${classOf[CreateTableAsSelect].getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + + test("drop table") { + val tableName1 = "db.tab" + val tableIdent1 = TableIdentifier("tab", Option("db")) + val tableName2 = "tab" + val 
tableIdent2 = TableIdentifier("tab", None) + + parseResolveCompare(s"DROP TABLE $tableName1", + DropTableCommand(tableIdent1, ifExists = false, isView = false, purge = false)) + parseResolveCompare(s"DROP TABLE IF EXISTS $tableName1", + DropTableCommand(tableIdent1, ifExists = true, isView = false, purge = false)) + parseResolveCompare(s"DROP TABLE $tableName2", + DropTableCommand(tableIdent2, ifExists = false, isView = false, purge = false)) + parseResolveCompare(s"DROP TABLE IF EXISTS $tableName2", + DropTableCommand(tableIdent2, ifExists = true, isView = false, purge = false)) + parseResolveCompare(s"DROP TABLE $tableName2 PURGE", + DropTableCommand(tableIdent2, ifExists = false, isView = false, purge = true)) + parseResolveCompare(s"DROP TABLE IF EXISTS $tableName2 PURGE", + DropTableCommand(tableIdent2, ifExists = true, isView = false, purge = true)) + } + + test("drop table in v2 catalog") { + val tableName1 = "testcat.db.tab" + val tableIdent1 = Identifier.of(Array("db"), "tab") + val tableName2 = "testcat.tab" + val tableIdent2 = Identifier.of(Array.empty, "tab") + + parseResolveCompare(s"DROP TABLE $tableName1", + DropTable(testCat, tableIdent1, ifExists = false)) + parseResolveCompare(s"DROP TABLE IF EXISTS $tableName1", + DropTable(testCat, tableIdent1, ifExists = true)) + parseResolveCompare(s"DROP TABLE $tableName2", + DropTable(testCat, tableIdent2, ifExists = false)) + parseResolveCompare(s"DROP TABLE IF EXISTS $tableName2", + DropTable(testCat, tableIdent2, ifExists = true)) + } + + test("drop view") { + val viewName1 = "db.view" + val viewIdent1 = TableIdentifier("view", Option("db")) + val viewName2 = "view" + val viewIdent2 = TableIdentifier("view") + + parseResolveCompare(s"DROP VIEW $viewName1", + DropTableCommand(viewIdent1, ifExists = false, isView = true, purge = false)) + parseResolveCompare(s"DROP VIEW IF EXISTS $viewName1", + DropTableCommand(viewIdent1, ifExists = true, isView = true, purge = false)) + parseResolveCompare(s"DROP VIEW $viewName2", + DropTableCommand(viewIdent2, ifExists = false, isView = true, purge = false)) + parseResolveCompare(s"DROP VIEW IF EXISTS $viewName2", + DropTableCommand(viewIdent2, ifExists = true, isView = true, purge = false)) + } + + test("drop view in v2 catalog") { + intercept[AnalysisException] { + parseAndResolve("DROP VIEW testcat.db.view") + }.getMessage.toLowerCase(Locale.ROOT).contains( + "view support in catalog has not been implemented") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala index f20aded169e44..2f5d5551c5df0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala @@ -219,6 +219,13 @@ class DataSourceStrategySuite extends PlanTest with SharedSQLContext { IsNotNull(attrInt))), None) } + test("SPARK-26865 DataSourceV2Strategy should push normalized filters") { + val attrInt = 'cint.int + assertResult(Seq(IsNotNull(attrInt))) { + DataSourceStrategy.normalizeFilters(Seq(IsNotNull(attrInt.withName("CiNt"))), Seq(attrInt)) + } + } + /** * Translate the given Catalyst [[Expression]] into data source [[sources.Filter]] * then verify against the given [[sources.Filter]]. 
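
The SPARK-26865 test above checks that predicates are normalized to the relation's own attribute names before being pushed to a V2 source. A compact sketch of the same call, assuming code living in the same sql package as that suite (normalizeFilters is internal) and using the Catalyst expression DSL.

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.IsNotNull
import org.apache.spark.sql.execution.datasources.DataSourceStrategy

// The relation exposes the column as "cint"; the incoming predicate says "CiNt".
val attrInt = 'cint.int
val normalized = DataSourceStrategy.normalizeFilters(
  Seq(IsNotNull(attrInt.withName("CiNt"))), Seq(attrInt))

// After normalization the filter names the column exactly as the relation does.
assert(normalized == Seq(IsNotNull(attrInt)))
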
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala index cccd8e9ee8bd1..034454d21d7ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -32,7 +32,6 @@ import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsR import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.datasources.v2.orc.OrcTable import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -58,7 +57,7 @@ class OrcFilterSuite extends OrcTest with SharedSQLContext { case PhysicalOperation(_, filters, DataSourceV2Relation(orcTable: OrcTable, _, options)) => assert(filters.nonEmpty, "No filter is analyzed from the given query") - val scanBuilder = orcTable.newScanBuilder(new DataSourceOptions(options.asJava)) + val scanBuilder = orcTable.newScanBuilder(options) scanBuilder.pushFilters(filters.flatMap(DataSourceStrategy.translateFilter).toArray) val pushedFilters = scanBuilder.pushedFilters() assert(pushedFilters.nonEmpty, "No filter is pushed down") @@ -102,7 +101,7 @@ class OrcFilterSuite extends OrcTest with SharedSQLContext { case PhysicalOperation(_, filters, DataSourceV2Relation(orcTable: OrcTable, _, options)) => assert(filters.nonEmpty, "No filter is analyzed from the given query") - val scanBuilder = orcTable.newScanBuilder(new DataSourceOptions(options.asJava)) + val scanBuilder = orcTable.newScanBuilder(options) scanBuilder.pushFilters(filters.flatMap(DataSourceStrategy.translateFilter).toArray) val pushedFilters = scanBuilder.pushedFilters() if (noneSupported) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala index 4a695ac74c476..b4d92c3b2d2fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources.orc import java.io.File +import org.apache.hadoop.fs.{Path, PathFilter} + import org.apache.spark.SparkConf import org.apache.spark.sql._ import org.apache.spark.sql.internal.SQLConf @@ -30,6 +32,10 @@ case class OrcParData(intField: Int, stringField: String) // The data that also includes the partitioning key case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) +class TestFileFilter extends PathFilter { + override def accept(path: Path): Boolean = path.getParent.getName != "p=2" +} + abstract class OrcPartitionDiscoveryTest extends OrcTest { val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__" @@ -226,6 +232,23 @@ abstract class OrcPartitionDiscoveryTest extends OrcTest { } } } + + test("SPARK-27162: handle pathfilter configuration correctly") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df = spark.range(2) + df.write.orc(path + "/p=1") + df.write.orc(path + "/p=2") + assert(spark.read.orc(path).count() === 4) + + val extraOptions = Map( + 
"mapred.input.pathFilter.class" -> classOf[TestFileFilter].getName, + "mapreduce.input.pathFilter.class" -> classOf[TestFileFilter].getName + ) + assert(spark.read.options(extraOptions).orc(path).count() === 2) + } + } } class OrcPartitionDiscoverySuite extends OrcPartitionDiscoveryTest with SharedSQLContext diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2StreamingScanSupportCheckSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2StreamingScanSupportCheckSuite.scala new file mode 100644 index 0000000000000..8a0450fce76a1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2StreamingScanSupportCheckSuite.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import java.util + +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} +import org.apache.spark.sql.catalyst.plans.logical.Union +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.streaming.{Offset, Source, StreamingRelation, StreamingRelationV2} +import org.apache.spark.sql.sources.StreamSourceProvider +import org.apache.spark.sql.sources.v2.{Table, TableCapability, TableProvider} +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class V2StreamingScanSupportCheckSuite extends SparkFunSuite with SharedSparkSession { + import TableCapability._ + + private def createStreamingRelation(table: Table, v1Relation: Option[StreamingRelation]) = { + StreamingRelationV2(FakeTableProvider, "fake", table, CaseInsensitiveStringMap.empty(), + FakeTableProvider.schema.toAttributes, v1Relation)(spark) + } + + private def createStreamingRelationV1() = { + StreamingRelation(DataSource(spark, classOf[FakeStreamSourceProvider].getName)) + } + + test("check correct plan") { + val plan1 = createStreamingRelation(CapabilityTable(MICRO_BATCH_READ), None) + val plan2 = createStreamingRelation(CapabilityTable(CONTINUOUS_READ), None) + val plan3 = createStreamingRelation(CapabilityTable(MICRO_BATCH_READ, CONTINUOUS_READ), None) + val plan4 = createStreamingRelationV1() + + V2StreamingScanSupportCheck(Union(plan1, plan1)) + V2StreamingScanSupportCheck(Union(plan2, plan2)) + V2StreamingScanSupportCheck(Union(plan1, plan3)) + V2StreamingScanSupportCheck(Union(plan2, plan3)) + V2StreamingScanSupportCheck(Union(plan1, plan4)) + V2StreamingScanSupportCheck(Union(plan3, plan4)) + } + + test("table without scan capability") { + val e = intercept[AnalysisException] { + 
V2StreamingScanSupportCheck(createStreamingRelation(CapabilityTable(), None)) + } + assert(e.message.contains("does not support either micro-batch or continuous scan")) + } + + test("mix micro-batch only and continuous only") { + val plan1 = createStreamingRelation(CapabilityTable(MICRO_BATCH_READ), None) + val plan2 = createStreamingRelation(CapabilityTable(CONTINUOUS_READ), None) + + val e = intercept[AnalysisException] { + V2StreamingScanSupportCheck(Union(plan1, plan2)) + } + assert(e.message.contains( + "The streaming sources in a query do not have a common supported execution mode")) + } + + test("mix continuous only and v1 relation") { + val plan1 = createStreamingRelation(CapabilityTable(CONTINUOUS_READ), None) + val plan2 = createStreamingRelationV1() + val e = intercept[AnalysisException] { + V2StreamingScanSupportCheck(Union(plan1, plan2)) + } + assert(e.message.contains( + "The streaming sources in a query do not have a common supported execution mode")) + } +} + +private object FakeTableProvider extends TableProvider { + val schema = new StructType().add("i", "int") + + override def getTable(options: CaseInsensitiveStringMap): Table = { + throw new UnsupportedOperationException + } +} + +private case class CapabilityTable(_capabilities: TableCapability*) extends Table { + override def name(): String = "capability_test_table" + override def schema(): StructType = FakeTableProvider.schema + override def capabilities(): util.Set[TableCapability] = _capabilities.toSet.asJava +} + +private class FakeStreamSourceProvider extends StreamSourceProvider { + override def sourceSchema( + sqlContext: SQLContext, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): (String, StructType) = { + "fake" -> FakeTableProvider.schema + } + + override def createSource( + sqlContext: SQLContext, + metadataPath: String, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { + new Source { + override def schema: StructType = FakeTableProvider.schema + override def getOffset: Option[Offset] = { + throw new UnsupportedOperationException + } + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + throw new UnsupportedOperationException + } + override def stop(): Unit = {} + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala index 3bc36ce55d902..3ead91fcf712a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala @@ -22,6 +22,8 @@ import scala.language.implicitConversions import org.scalatest.BeforeAndAfter import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.streaming.sources._ import org.apache.spark.sql.streaming.{OutputMode, StreamTest} import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -36,7 +38,8 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("directly add data in Append output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Append) + val sink = new MemorySink + val addBatch = addBatchFunc(sink, false) _ // Before adding data, check output assert(sink.latestBatchId === None) @@ 
-44,25 +47,25 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { checkAnswer(sink.allData, Seq.empty) // Add batch 0 and check outputs - sink.addBatch(0, 1 to 3) + addBatch(0, 1 to 3) assert(sink.latestBatchId === Some(0)) checkAnswer(sink.latestBatchData, 1 to 3) checkAnswer(sink.allData, 1 to 3) // Add batch 1 and check outputs - sink.addBatch(1, 4 to 6) + addBatch(1, 4 to 6) assert(sink.latestBatchId === Some(1)) checkAnswer(sink.latestBatchData, 4 to 6) checkAnswer(sink.allData, 1 to 6) // new data should get appended to old data // Re-add batch 1 with different data, should not be added and outputs should not be changed - sink.addBatch(1, 7 to 9) + addBatch(1, 7 to 9) assert(sink.latestBatchId === Some(1)) checkAnswer(sink.latestBatchData, 4 to 6) checkAnswer(sink.allData, 1 to 6) // Add batch 2 and check outputs - sink.addBatch(2, 7 to 9) + addBatch(2, 7 to 9) assert(sink.latestBatchId === Some(2)) checkAnswer(sink.latestBatchData, 7 to 9) checkAnswer(sink.allData, 1 to 9) @@ -70,7 +73,8 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("directly add data in Update output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Update) + val sink = new MemorySink + val addBatch = addBatchFunc(sink, false) _ // Before adding data, check output assert(sink.latestBatchId === None) @@ -78,25 +82,25 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { checkAnswer(sink.allData, Seq.empty) // Add batch 0 and check outputs - sink.addBatch(0, 1 to 3) + addBatch(0, 1 to 3) assert(sink.latestBatchId === Some(0)) checkAnswer(sink.latestBatchData, 1 to 3) checkAnswer(sink.allData, 1 to 3) // Add batch 1 and check outputs - sink.addBatch(1, 4 to 6) + addBatch(1, 4 to 6) assert(sink.latestBatchId === Some(1)) checkAnswer(sink.latestBatchData, 4 to 6) checkAnswer(sink.allData, 1 to 6) // new data should get appended to old data // Re-add batch 1 with different data, should not be added and outputs should not be changed - sink.addBatch(1, 7 to 9) + addBatch(1, 7 to 9) assert(sink.latestBatchId === Some(1)) checkAnswer(sink.latestBatchData, 4 to 6) checkAnswer(sink.allData, 1 to 6) // Add batch 2 and check outputs - sink.addBatch(2, 7 to 9) + addBatch(2, 7 to 9) assert(sink.latestBatchId === Some(2)) checkAnswer(sink.latestBatchData, 7 to 9) checkAnswer(sink.allData, 1 to 9) @@ -104,7 +108,8 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("directly add data in Complete output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Complete) + val sink = new MemorySink + val addBatch = addBatchFunc(sink, true) _ // Before adding data, check output assert(sink.latestBatchId === None) @@ -112,25 +117,25 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { checkAnswer(sink.allData, Seq.empty) // Add batch 0 and check outputs - sink.addBatch(0, 1 to 3) + addBatch(0, 1 to 3) assert(sink.latestBatchId === Some(0)) checkAnswer(sink.latestBatchData, 1 to 3) checkAnswer(sink.allData, 1 to 3) // Add batch 1 and check outputs - sink.addBatch(1, 4 to 6) + addBatch(1, 4 to 6) assert(sink.latestBatchId === Some(1)) checkAnswer(sink.latestBatchData, 4 to 6) checkAnswer(sink.allData, 4 to 6) // new data should replace old data // Re-add batch 1 with different data, should not be added and outputs should not be changed - sink.addBatch(1, 7 to 9) + addBatch(1, 7 to 9) 
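// A minimal sketch of what the addBatchFunc helper below boils down to: the
// OutputMode-specific MemorySink constructors are gone, and truncation is now
// requested per write. The write(batchId, needTruncate, rows) signature comes
// from the hunks in this file; the import path of the unified MemorySink is an
// assumption.
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.streaming.sources.MemorySink

val sink = new MemorySink
def addBatch(needTruncate: Boolean)(batchId: Long, vals: Seq[Int]): Unit =
  sink.write(batchId, needTruncate, vals.map(Row(_)).toArray)

addBatch(needTruncate = false)(0, 1 to 3) // keeps earlier batches, like Append/Update
addBatch(needTruncate = true)(1, 4 to 6)  // replaces earlier batches, like Complete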
assert(sink.latestBatchId === Some(1)) checkAnswer(sink.latestBatchData, 4 to 6) checkAnswer(sink.allData, 4 to 6) // Add batch 2 and check outputs - sink.addBatch(2, 7 to 9) + addBatch(2, 7 to 9) assert(sink.latestBatchId === Some(2)) checkAnswer(sink.latestBatchData, 7 to 9) checkAnswer(sink.allData, 7 to 9) @@ -211,18 +216,19 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("MemoryPlan statistics") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Append) - val plan = new MemoryPlan(sink) + val sink = new MemorySink + val addBatch = addBatchFunc(sink, false) _ + val plan = new MemoryPlan(sink, schema.toAttributes) // Before adding data, check output checkAnswer(sink.allData, Seq.empty) assert(plan.stats.sizeInBytes === 0) - sink.addBatch(0, 1 to 3) + addBatch(0, 1 to 3) plan.invalidateStatsCache() assert(plan.stats.sizeInBytes === 36) - sink.addBatch(1, 4 to 6) + addBatch(1, 4 to 6) plan.invalidateStatsCache() assert(plan.stats.sizeInBytes === 72) } @@ -285,6 +291,50 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { } } + test("data writer") { + val partition = 1234 + val writer = new MemoryDataWriter( + partition, new StructType().add("i", "int")) + writer.write(InternalRow(1)) + writer.write(InternalRow(2)) + writer.write(InternalRow(44)) + val msg = writer.commit() + assert(msg.data.map(_.getInt(0)) == Seq(1, 2, 44)) + assert(msg.partition == partition) + + // Buffer should be cleared, so repeated commits should give empty. + assert(writer.commit().data.isEmpty) + } + + test("streaming writer") { + val sink = new MemorySink + val write = new MemoryStreamingWrite( + sink, new StructType().add("i", "int"), needTruncate = false) + write.commit(0, + Array( + MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), + MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), + MemoryWriterCommitMessage(2, Seq(Row(6), Row(7))) + )) + assert(sink.latestBatchId.contains(0)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) + write.commit(19, + Array( + MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), + MemoryWriterCommitMessage(0, Seq(Row(33))) + )) + assert(sink.latestBatchId.contains(19)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33)) + + assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33)) + } + + private def addBatchFunc(sink: MemorySink, needTruncate: Boolean)( + batchId: Long, + vals: Seq[Int]): Unit = { + sink.write(batchId, needTruncate, vals.map(Row(_)).toArray) + } + private def checkAnswer(rows: Seq[Row], expected: Seq[Int])(implicit schema: StructType): Unit = { checkAnswer( sqlContext.createDataFrame(sparkContext.makeRDD(rows), schema), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala deleted file mode 100644 index 61857365ac989..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import org.scalatest.BeforeAndAfter - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.streaming.sources._ -import org.apache.spark.sql.streaming.{OutputMode, StreamTest} -import org.apache.spark.sql.types.StructType - -class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { - test("data writer") { - val partition = 1234 - val writer = new MemoryDataWriter( - partition, OutputMode.Append(), new StructType().add("i", "int")) - writer.write(InternalRow(1)) - writer.write(InternalRow(2)) - writer.write(InternalRow(44)) - val msg = writer.commit() - assert(msg.data.map(_.getInt(0)) == Seq(1, 2, 44)) - assert(msg.partition == partition) - - // Buffer should be cleared, so repeated commits should give empty. - assert(writer.commit().data.isEmpty) - } - - test("streaming writer") { - val sink = new MemorySinkV2 - val writeSupport = new MemoryStreamingWriteSupport( - sink, OutputMode.Append(), new StructType().add("i", "int")) - writeSupport.commit(0, - Array( - MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), - MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), - MemoryWriterCommitMessage(2, Seq(Row(6), Row(7))) - )) - assert(sink.latestBatchId.contains(0)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) - writeSupport.commit(19, - Array( - MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), - MemoryWriterCommitMessage(0, Seq(Row(33))) - )) - assert(sink.latestBatchId.contains(19)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33)) - - assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33)) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index d0418f893143e..ef88598fcb11b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -29,9 +29,9 @@ import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relati import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.streaming.Offset +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset, SparkDataStream} import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.ManualClock class RateStreamProviderSuite extends StreamTest { @@ -39,7 +39,7 @@ class RateStreamProviderSuite extends StreamTest { import testImplicits._ case class AdvanceRateManualClock(seconds: Long) extends AddData { - override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + 
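// A minimal sketch of the signature change applied to the custom AddData actions
// in these suites: addData now returns (SparkDataStream, Offset) instead of
// (BaseStreamingSource, Offset). The body is elided; the action name is
// illustrative and assumes the surrounding suite mixes in StreamTest.
case class AdvanceClockExample(seconds: Long) extends AddData {
  override def addData(query: Option[StreamExecution]): (SparkDataStream, Offset) = {
    assert(query.nonEmpty, "cannot advance the clock without an active streaming query")
    // Locate the source in query.get.logicalPlan and pair it with the offset
    // reached after advancing the manual clock by `seconds` (elided here).
    ???
  }
}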
override def addData(query: Option[StreamExecution]): (SparkDataStream, Offset) = { assert(query.nonEmpty) val rateSource = query.get.logicalPlan.collect { case r: StreamingDataSourceV2Relation @@ -135,7 +135,7 @@ class RateStreamProviderSuite extends StreamTest { withTempDir { temp => val stream = new RateStreamMicroBatchStream( rowsPerSecond = 100, - options = new DataSourceOptions(Map("useManualClock" -> "true").asJava), + options = new CaseInsensitiveStringMap(Map("useManualClock" -> "true").asJava), checkpointLocation = temp.getCanonicalPath) stream.clock.asInstanceOf[ManualClock].advance(100000) val startOffset = stream.initialOffset() @@ -154,7 +154,7 @@ class RateStreamProviderSuite extends StreamTest { withTempDir { temp => val stream = new RateStreamMicroBatchStream( rowsPerSecond = 20, - options = DataSourceOptions.empty(), + options = CaseInsensitiveStringMap.empty(), checkpointLocation = temp.getCanonicalPath) val partitions = stream.planInputPartitions(LongOffset(0L), LongOffset(1L)) val readerFactory = stream.createReaderFactory() @@ -173,7 +173,7 @@ class RateStreamProviderSuite extends StreamTest { val stream = new RateStreamMicroBatchStream( rowsPerSecond = 33, numPartitions = 11, - options = DataSourceOptions.empty(), + options = CaseInsensitiveStringMap.empty(), checkpointLocation = temp.getCanonicalPath) val partitions = stream.planInputPartitions(LongOffset(0L), LongOffset(1L)) val readerFactory = stream.createReaderFactory() @@ -305,12 +305,11 @@ class RateStreamProviderSuite extends StreamTest { .load() } assert(exception.getMessage.contains( - "rate source does not support user-specified schema")) + "RateStreamProvider source does not support user-specified schema")) } test("continuous data") { - val stream = new RateStreamContinuousStream( - rowsPerSecond = 20, numPartitions = 2, options = DataSourceOptions.empty()) + val stream = new RateStreamContinuousStream(rowsPerSecond = 20, numPartitions = 2) val partitions = stream.planInputPartitions(stream.initialOffset) val readerFactory = stream.createContinuousReaderFactory() assert(partitions.size == 2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala index 33c65d784fba6..3c451e0538721 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -35,11 +35,11 @@ import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relati import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.streaming.Offset +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset, SparkDataStream} import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap class TextSocketStreamSuite extends StreamTest with SharedSQLContext with BeforeAndAfterEach { @@ -55,7 +55,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before private var serverThread: ServerThread = null case class AddSocketData(data: String*) 
extends AddData { - override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + override def addData(query: Option[StreamExecution]): (SparkDataStream, Offset) = { require( query.nonEmpty, "Cannot add data when there is no query for finding the active socket source") @@ -176,13 +176,13 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before test("params not given") { val provider = new TextSocketSourceProvider intercept[AnalysisException] { - provider.getTable(new DataSourceOptions(Map.empty[String, String].asJava)) + provider.getTable(CaseInsensitiveStringMap.empty()) } intercept[AnalysisException] { - provider.getTable(new DataSourceOptions(Map("host" -> "localhost").asJava)) + provider.getTable(new CaseInsensitiveStringMap(Map("host" -> "localhost").asJava)) } intercept[AnalysisException] { - provider.getTable(new DataSourceOptions(Map("port" -> "1234").asJava)) + provider.getTable(new CaseInsensitiveStringMap(Map("port" -> "1234").asJava)) } } @@ -190,7 +190,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before val provider = new TextSocketSourceProvider val params = Map("host" -> "localhost", "port" -> "1234", "includeTimestamp" -> "fasle") intercept[AnalysisException] { - provider.getTable(new DataSourceOptions(params.asJava)) + provider.getTable(new CaseInsensitiveStringMap(params.asJava)) } } @@ -201,10 +201,10 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before StructField("area", StringType) :: Nil) val params = Map("host" -> "localhost", "port" -> "1234") val exception = intercept[UnsupportedOperationException] { - provider.getTable(new DataSourceOptions(params.asJava), userSpecifiedSchema) + provider.getTable(new CaseInsensitiveStringMap(params.asJava), userSpecifiedSchema) } assert(exception.getMessage.contains( - "socket source does not support user-specified schema")) + "TextSocketSourceProvider source does not support user-specified schema")) } test("input row metrics") { @@ -299,7 +299,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before host = "localhost", port = serverThread.port, numPartitions = 2, - options = DataSourceOptions.empty()) + options = CaseInsensitiveStringMap.empty()) val partitions = stream.planInputPartitions(stream.initialOffset()) assert(partitions.length == 2) @@ -351,7 +351,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before host = "localhost", port = serverThread.port, numPartitions = 2, - options = DataSourceOptions.empty()) + options = CaseInsensitiveStringMap.empty()) stream.startOffset = TextSocketOffset(List(5, 5)) assertThrows[IllegalStateException] { @@ -367,7 +367,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before host = "localhost", port = serverThread.port, numPartitions = 2, - options = new DataSourceOptions(Map("includeTimestamp" -> "true").asJava)) + options = new CaseInsensitiveStringMap(Map("includeTimestamp" -> "true").asJava)) val partitions = stream.planInputPartitions(stream.initialOffset()) assert(partitions.size == 2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala index 4592a1663faed..60f1b32a41f05 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -21,8 +21,8 @@ import org.apache.arrow.vector._ import org.apache.arrow.vector.complex._ import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.sql.vectorized.ArrowColumnVector import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index e8062dbb91e35..4dd65385d548b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -31,9 +31,9 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.memory.MemoryMode import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types._ -import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types.CalendarInterval diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala new file mode 100644 index 0000000000000..5b9071b59b9b0 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala @@ -0,0 +1,285 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.sources.v2 + +import scala.collection.JavaConverters._ + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.catalog.v2.Identifier +import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2 +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{LongType, StringType, StructType} + +class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAndAfter { + + import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._ + + private val orc2 = classOf[OrcDataSourceV2].getName + + before { + spark.conf.set("spark.sql.catalog.testcat", classOf[TestInMemoryTableCatalog].getName) + spark.conf.set("spark.sql.default.catalog", "testcat") + + val df = spark.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"))).toDF("id", "data") + df.createOrReplaceTempView("source") + val df2 = spark.createDataFrame(Seq((4L, "d"), (5L, "e"), (6L, "f"))).toDF("id", "data") + df2.createOrReplaceTempView("source2") + } + + after { + spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog].clearTables() + spark.sql("DROP TABLE source") + } + + test("CreateTable: use v2 plan because catalog is set") { + spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name == "testcat.table_name") + assert(table.partitioning.isEmpty) + assert(table.properties == Map("provider" -> "foo").asJava) + assert(table.schema == new StructType().add("id", LongType).add("data", StringType)) + + val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows) + checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), Seq.empty) + } + + test("CreateTable: use v2 plan because provider is v2") { + spark.sql(s"CREATE TABLE table_name (id bigint, data string) USING $orc2") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name == "testcat.table_name") + assert(table.partitioning.isEmpty) + assert(table.properties == Map("provider" -> orc2).asJava) + assert(table.schema == new StructType().add("id", LongType).add("data", StringType)) + + val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows) + checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), Seq.empty) + } + + test("CreateTable: fail if table exists") { + spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo") + + val testCatalog = spark.catalog("testcat").asTableCatalog + + val table = testCatalog.loadTable(Identifier.of(Array(), "table_name")) + assert(table.name == "testcat.table_name") + assert(table.partitioning.isEmpty) + assert(table.properties == Map("provider" -> "foo").asJava) + assert(table.schema == new StructType().add("id", LongType).add("data", StringType)) + + // run a second create query that should fail + val exc = intercept[TableAlreadyExistsException] { + spark.sql("CREATE TABLE testcat.table_name (id bigint, data string, id2 bigint) USING bar") + } + + assert(exc.getMessage.contains("table_name")) + + // table should not have changed + val table2 = testCatalog.loadTable(Identifier.of(Array(), "table_name")) + assert(table2.name == 
"testcat.table_name") + assert(table2.partitioning.isEmpty) + assert(table2.properties == Map("provider" -> "foo").asJava) + assert(table2.schema == new StructType().add("id", LongType).add("data", StringType)) + + // check that the table is still empty + val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows) + checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), Seq.empty) + } + + test("CreateTable: if not exists") { + spark.sql( + "CREATE TABLE IF NOT EXISTS testcat.table_name (id bigint, data string) USING foo") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name == "testcat.table_name") + assert(table.partitioning.isEmpty) + assert(table.properties == Map("provider" -> "foo").asJava) + assert(table.schema == new StructType().add("id", LongType).add("data", StringType)) + + spark.sql("CREATE TABLE IF NOT EXISTS testcat.table_name (id bigint, data string) USING bar") + + // table should not have changed + val table2 = testCatalog.loadTable(Identifier.of(Array(), "table_name")) + assert(table2.name == "testcat.table_name") + assert(table2.partitioning.isEmpty) + assert(table2.properties == Map("provider" -> "foo").asJava) + assert(table2.schema == new StructType().add("id", LongType).add("data", StringType)) + + // check that the table is still empty + val rdd2 = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows) + checkAnswer(spark.internalCreateDataFrame(rdd2, table.schema), Seq.empty) + } + + test("CreateTable: fail analysis when default catalog is needed but missing") { + val originalDefaultCatalog = conf.getConfString("spark.sql.default.catalog") + try { + conf.unsetConf("spark.sql.default.catalog") + + val exc = intercept[AnalysisException] { + spark.sql(s"CREATE TABLE table_name USING $orc2 AS SELECT id, data FROM source") + } + + assert(exc.getMessage.contains("No catalog specified for table")) + assert(exc.getMessage.contains("table_name")) + assert(exc.getMessage.contains("no default catalog is set")) + + } finally { + conf.setConfString("spark.sql.default.catalog", originalDefaultCatalog) + } + } + + test("CreateTableAsSelect: use v2 plan because catalog is set") { + spark.sql("CREATE TABLE testcat.table_name USING foo AS SELECT id, data FROM source") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name == "testcat.table_name") + assert(table.partitioning.isEmpty) + assert(table.properties == Map("provider" -> "foo").asJava) + assert(table.schema == new StructType() + .add("id", LongType, nullable = false) + .add("data", StringType)) + + val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows) + checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), spark.table("source")) + } + + test("CreateTableAsSelect: use v2 plan because provider is v2") { + spark.sql(s"CREATE TABLE table_name USING $orc2 AS SELECT id, data FROM source") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name == "testcat.table_name") + assert(table.partitioning.isEmpty) + assert(table.properties == Map("provider" -> orc2).asJava) + assert(table.schema == new StructType() + .add("id", LongType, nullable = false) + .add("data", StringType)) + + val rdd = 
spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows) + checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), spark.table("source")) + } + + test("CreateTableAsSelect: fail if table exists") { + spark.sql("CREATE TABLE testcat.table_name USING foo AS SELECT id, data FROM source") + + val testCatalog = spark.catalog("testcat").asTableCatalog + + val table = testCatalog.loadTable(Identifier.of(Array(), "table_name")) + assert(table.name == "testcat.table_name") + assert(table.partitioning.isEmpty) + assert(table.properties == Map("provider" -> "foo").asJava) + assert(table.schema == new StructType() + .add("id", LongType, nullable = false) + .add("data", StringType)) + + val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows) + checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), spark.table("source")) + + // run a second CTAS query that should fail + val exc = intercept[TableAlreadyExistsException] { + spark.sql( + "CREATE TABLE testcat.table_name USING bar AS SELECT id, data, id as id2 FROM source2") + } + + assert(exc.getMessage.contains("table_name")) + + // table should not have changed + val table2 = testCatalog.loadTable(Identifier.of(Array(), "table_name")) + assert(table2.name == "testcat.table_name") + assert(table2.partitioning.isEmpty) + assert(table2.properties == Map("provider" -> "foo").asJava) + assert(table2.schema == new StructType() + .add("id", LongType, nullable = false) + .add("data", StringType)) + + val rdd2 = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows) + checkAnswer(spark.internalCreateDataFrame(rdd2, table.schema), spark.table("source")) + } + + test("CreateTableAsSelect: if not exists") { + spark.sql( + "CREATE TABLE IF NOT EXISTS testcat.table_name USING foo AS SELECT id, data FROM source") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name == "testcat.table_name") + assert(table.partitioning.isEmpty) + assert(table.properties == Map("provider" -> "foo").asJava) + assert(table.schema == new StructType() + .add("id", LongType, nullable = false) + .add("data", StringType)) + + val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows) + checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), spark.table("source")) + + spark.sql( + "CREATE TABLE IF NOT EXISTS testcat.table_name USING foo AS SELECT id, data FROM source2") + + // check that the table contains data from just the first CTAS + val rdd2 = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows) + checkAnswer(spark.internalCreateDataFrame(rdd2, table.schema), spark.table("source")) + } + + test("CreateTableAsSelect: fail analysis when default catalog is needed but missing") { + val originalDefaultCatalog = conf.getConfString("spark.sql.default.catalog") + try { + conf.unsetConf("spark.sql.default.catalog") + + val exc = intercept[AnalysisException] { + spark.sql(s"CREATE TABLE table_name USING $orc2 AS SELECT id, data FROM source") + } + + assert(exc.getMessage.contains("No catalog specified for table")) + assert(exc.getMessage.contains("table_name")) + assert(exc.getMessage.contains("no default catalog is set")) + + } finally { + conf.setConfString("spark.sql.default.catalog", originalDefaultCatalog) + } + } + + test("DropTable: basic") { + val tableName = "testcat.ns1.ns2.tbl" + val ident = Identifier.of(Array("ns1", "ns2"), "tbl") + sql(s"CREATE TABLE $tableName USING foo 
AS SELECT id, data FROM source") + assert(spark.catalog("testcat").asTableCatalog.tableExists(ident) === true) + sql(s"DROP TABLE $tableName") + assert(spark.catalog("testcat").asTableCatalog.tableExists(ident) === false) + } + + test("DropTable: if exists") { + intercept[NoSuchTableException] { + sql(s"DROP TABLE testcat.db.notbl") + } + sql(s"DROP TABLE IF EXISTS testcat.db.notbl") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 511fdfe5c23ac..379c9c4303cd6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -18,21 +18,27 @@ package org.apache.spark.sql.sources.v2 import java.io.File +import java.util +import java.util.OptionalLong + +import scala.collection.JavaConverters._ import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException -import org.apache.spark.sql.{DataFrame, QueryTest, Row} +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation} import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.functions._ import org.apache.spark.sql.sources.{Filter, GreaterThan} +import org.apache.spark.sql.sources.v2.TableCapability._ import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.partitioning.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.sql.vectorized.ColumnarBatch class DataSourceV2Suite extends QueryTest with SharedSQLContext { @@ -182,6 +188,24 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } + test ("statistics report data source") { + Seq(classOf[ReportStatisticsDataSource], classOf[JavaReportStatisticsDataSource]).foreach { + cls => + withClue(cls.getName) { + val df = spark.read.format(cls.getName).load() + val logical = df.queryExecution.optimizedPlan.collect { + case d: DataSourceV2Relation => d + }.head + + val statics = logical.computeStats() + assert(statics.rowCount.isDefined && statics.rowCount.get === 10, + "Row count statics should be reported by data source") + assert(statics.sizeInBytes === 80, + "Size in bytes statics should be reported by data source") + } + } + } + test("SPARK-23574: no shuffle exchange with single partition") { val df = spark.read.format(classOf[SimpleSinglePartitionSource].getName).load().agg(count("*")) assert(df.queryExecution.executedPlan.collect { case e: Exchange => e }.isEmpty) @@ -195,14 +219,14 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) spark.range(10).select('id as 'i, -'id as 'j).write.format(cls.getName) - .option("path", path).save() + .option("path", path).mode("append").save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(), spark.range(10).select('id, -'id)) - // test with different save modes + // default save mode is append spark.range(10).select('id as 'i, -'id as 
'j).write.format(cls.getName) - .option("path", path).mode("append").save() + .option("path", path).save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(), spark.range(10).union(spark.range(10)).select('id, -'id)) @@ -213,17 +237,17 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { spark.read.format(cls.getName).option("path", path).load(), spark.range(5).select('id, -'id)) - spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName) - .option("path", path).mode("ignore").save() - checkAnswer( - spark.read.format(cls.getName).option("path", path).load(), - spark.range(5).select('id, -'id)) + val e = intercept[AnalysisException] { + spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName) + .option("path", path).mode("ignore").save() + } + assert(e.message.contains("please use Append or Overwrite modes instead")) - val e = intercept[Exception] { + val e2 = intercept[AnalysisException] { spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName) .option("path", path).mode("error").save() } - assert(e.getMessage.contains("data already exists")) + assert(e2.getMessage.contains("please use Append or Overwrite modes instead")) // test transaction val failingUdf = org.apache.spark.sql.functions.udf { @@ -238,10 +262,10 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } // this input data will fail to read middle way. val input = spark.range(10).select(failingUdf('id).as('i)).select('i, -'i as 'j) - val e2 = intercept[SparkException] { + val e3 = intercept[SparkException] { input.write.format(cls.getName).option("path", path).mode("overwrite").save() } - assert(e2.getMessage.contains("Writing job aborted")) + assert(e3.getMessage.contains("Writing job aborted")) // make sure we don't have partial data. assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) } @@ -330,7 +354,7 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { val options = df.queryExecution.optimizedPlan.collectFirst { case d: DataSourceV2Relation => d.options }.get - assert(options.get(optionName).get == "false") + assert(options.get(optionName) === "false") } } @@ -351,19 +375,16 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } - test("SPARK-25700: do not read schema when writing in other modes except append mode") { - withTempPath { file => - val cls = classOf[SimpleWriteOnlyDataSource] - val path = file.getCanonicalPath - val df = spark.range(5).select('id as 'i, -'id as 'j) - // non-append mode should not throw exception, as they don't access schema. - df.write.format(cls.getName).option("path", path).mode("error").save() - df.write.format(cls.getName).option("path", path).mode("overwrite").save() - df.write.format(cls.getName).option("path", path).mode("ignore").save() - // append mode will access schema and should throw exception. 
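// A minimal sketch of the save-mode behavior the hunk above now asserts for a
// TableProvider-based source such as SimpleWritableDataSource: append and
// overwrite succeed, while error/ignore are rejected at analysis time. `spark`
// and `path` stand for the session and temp path already used in that test.
val cls = classOf[SimpleWritableDataSource].getName
spark.range(3).selectExpr("id AS i", "-id AS j")
  .write.format(cls).option("path", path).mode("append").save()    // appends
spark.range(3).selectExpr("id AS i", "-id AS j")
  .write.format(cls).option("path", path).mode("overwrite").save() // truncates first
intercept[AnalysisException] {
  spark.range(3).selectExpr("id AS i", "-id AS j")
    .write.format(cls).option("path", path).mode("ignore").save()  // not supported
}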
- intercept[SchemaReadAttemptException] { - df.write.format(cls.getName).option("path", path).mode("append").save() - } + test("SPARK-27411: DataSourceV2Strategy should not eliminate subquery") { + withTempView("t1") { + val t2 = spark.read.format(classOf[SimpleDataSourceV2].getName).load() + Seq(2, 3).toDF("a").createTempView("t1") + val df = t2.where("i < (select max(a) from t1)").select('i) + val subqueries = df.queryExecution.executedPlan.collect { + case p => p.subqueries + }.flatten + assert(subqueries.length == 1) + checkAnswer(df, (0 until 3).map(i => Row(i))) } } } @@ -389,11 +410,13 @@ object SimpleReaderFactory extends PartitionReaderFactory { } } -abstract class SimpleBatchTable extends Table with SupportsBatchRead { +abstract class SimpleBatchTable extends Table with SupportsRead { override def schema(): StructType = new StructType().add("i", "int").add("j", "int") override def name(): String = this.getClass.toString + + override def capabilities(): util.Set[TableCapability] = Set(BATCH_READ).asJava } abstract class SimpleScanBuilder extends ScanBuilder @@ -416,8 +439,8 @@ class SimpleSinglePartitionSource extends TableProvider { } } - override def getTable(options: DataSourceOptions): Table = new SimpleBatchTable { - override def newScanBuilder(options: DataSourceOptions): ScanBuilder = { + override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder() } } @@ -433,8 +456,8 @@ class SimpleDataSourceV2 extends TableProvider { } } - override def getTable(options: DataSourceOptions): Table = new SimpleBatchTable { - override def newScanBuilder(options: DataSourceOptions): ScanBuilder = { + override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder() } } @@ -442,8 +465,8 @@ class SimpleDataSourceV2 extends TableProvider { class AdvancedDataSourceV2 extends TableProvider { - override def getTable(options: DataSourceOptions): Table = new SimpleBatchTable { - override def newScanBuilder(options: DataSourceOptions): ScanBuilder = { + override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new AdvancedScanBuilder() } } @@ -538,16 +561,16 @@ class SchemaRequiredDataSource extends TableProvider { override def readSchema(): StructType = schema } - override def getTable(options: DataSourceOptions): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { throw new IllegalArgumentException("requires a user-supplied schema") } - override def getTable(options: DataSourceOptions, schema: StructType): Table = { + override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = { val userGivenSchema = schema new SimpleBatchTable { override def schema(): StructType = userGivenSchema - override def newScanBuilder(options: DataSourceOptions): ScanBuilder = { + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder(userGivenSchema) } } @@ -567,8 +590,8 @@ class ColumnarDataSourceV2 extends TableProvider { } } - override def getTable(options: DataSourceOptions): Table = new SimpleBatchTable { - override def newScanBuilder(options: DataSourceOptions): ScanBuilder = { + override def getTable(options: CaseInsensitiveStringMap): Table = new 
SimpleBatchTable { + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder() } } @@ -619,7 +642,6 @@ object ColumnarReaderFactory extends PartitionReaderFactory { } } - class PartitionAwareDataSource extends TableProvider { class MyScanBuilder extends SimpleScanBuilder @@ -639,8 +661,8 @@ class PartitionAwareDataSource extends TableProvider { override def outputPartitioning(): Partitioning = new MyPartitioning } - override def getTable(options: DataSourceOptions): Table = new SimpleBatchTable { - override def newScanBuilder(options: DataSourceOptions): ScanBuilder = { + override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder() } } @@ -679,7 +701,7 @@ class SchemaReadAttemptException(m: String) extends RuntimeException(m) class SimpleWriteOnlyDataSource extends SimpleWritableDataSource { - override def getTable(options: DataSourceOptions): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { new MyTable(options) { override def schema(): StructType = { throw new SchemaReadAttemptException("schema should not be read.") @@ -687,3 +709,29 @@ class SimpleWriteOnlyDataSource extends SimpleWritableDataSource { } } } + +class ReportStatisticsDataSource extends TableProvider { + + class MyScanBuilder extends SimpleScanBuilder + with SupportsReportStatistics { + override def estimateStatistics(): Statistics = { + new Statistics { + override def sizeInBytes(): OptionalLong = OptionalLong.of(80) + + override def numRows(): OptionalLong = OptionalLong.of(10) + } + } + + override def planInputPartitions(): Array[InputPartition] = { + Array(RangeInputPartition(0, 5), RangeInputPartition(5, 10)) + } + } + + override def getTable(options: CaseInsensitiveStringMap): Table = { + new SimpleBatchTable { + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + new MyScanBuilder + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala index f903c17923d0f..0b1e3b5fb076d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala @@ -33,8 +33,8 @@ class DataSourceV2UtilsSuite extends SparkFunSuite { conf.setConfString(s"spark.sql.$keyPrefix.config.name", "false") conf.setConfString("spark.datasource.another.config.name", "123") conf.setConfString(s"spark.datasource.$keyPrefix.", "123") - val cs = classOf[DataSourceV2WithSessionConfig].getConstructor().newInstance() - val confs = DataSourceV2Utils.extractSessionConfigs(cs.asInstanceOf[DataSourceV2], conf) + val source = new DataSourceV2WithSessionConfig + val confs = DataSourceV2Utils.extractSessionConfigs(source, conf) assert(confs.size == 2) assert(confs.keySet.filter(_.startsWith("spark.datasource")).size == 0) assert(confs.keySet.filter(_.startsWith("not.exist.prefix")).size == 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala index fd19a48497fe6..e84c082128e1c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala @@ -16,15 +16,18 @@ */ package org.apache.spark.sql.sources.v2 +import scala.collection.JavaConverters._ + import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.execution.datasources.FileFormat -import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetTest} +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2.reader.ScanBuilder import org.apache.spark.sql.sources.v2.writer.WriteBuilder import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap class DummyReadOnlyFileDataSourceV2 extends FileDataSourceV2 { @@ -32,19 +35,22 @@ class DummyReadOnlyFileDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "parquet" - override def getTable(options: DataSourceOptions): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { new DummyReadOnlyFileTable } } -class DummyReadOnlyFileTable extends Table with SupportsBatchRead { +class DummyReadOnlyFileTable extends Table with SupportsRead { override def name(): String = "dummy" override def schema(): StructType = StructType(Nil) - override def newScanBuilder(options: DataSourceOptions): ScanBuilder = { + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { throw new AnalysisException("Dummy file reader") } + + override def capabilities(): java.util.Set[TableCapability] = + Set(TableCapability.BATCH_READ, TableCapability.ACCEPT_ANY_SCHEMA).asJava } class DummyWriteOnlyFileDataSourceV2 extends FileDataSourceV2 { @@ -53,18 +59,21 @@ class DummyWriteOnlyFileDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "parquet" - override def getTable(options: DataSourceOptions): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { new DummyWriteOnlyFileTable } } -class DummyWriteOnlyFileTable extends Table with SupportsBatchWrite { +class DummyWriteOnlyFileTable extends Table with SupportsWrite { override def name(): String = "dummy" override def schema(): StructType = StructType(Nil) - override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = + override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = throw new AnalysisException("Dummy file writer") + + override def capabilities(): java.util.Set[TableCapability] = + Set(TableCapability.BATCH_WRITE, TableCapability.ACCEPT_ANY_SCHEMA).asJava } class FileDataSourceV2FallBackSuite extends QueryTest with SharedSQLContext { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index daca65fd1ad2c..c9d2f1eef24bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources.v2 import java.io.{BufferedReader, InputStreamReader, IOException} +import java.util import scala.collection.JavaConverters._ @@ -25,12 +26,12 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import 
org.apache.spark.SparkContext -import org.apache.spark.internal.config.SPECULATION_ENABLED -import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.sources.v2.TableCapability._ import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration /** @@ -38,8 +39,7 @@ import org.apache.spark.util.SerializableConfiguration * Each task writes data to `target/_temporary/uniqueId/$jobId-$partitionId-$attemptNumber`. * Each job moves files from `target/_temporary/uniqueId/` to `target`. */ -class SimpleWritableDataSource extends DataSourceV2 - with TableProvider with SessionConfigSupport { +class SimpleWritableDataSource extends TableProvider with SessionConfigSupport { private val tableSchema = new StructType().add("i", "long").add("j", "long") @@ -69,38 +69,26 @@ class SimpleWritableDataSource extends DataSourceV2 override def readSchema(): StructType = tableSchema } - class MyWriteBuilder(path: String) extends WriteBuilder with SupportsSaveMode { + class MyWriteBuilder(path: String) extends WriteBuilder with SupportsTruncate { private var queryId: String = _ - private var mode: SaveMode = _ + private var needTruncate = false override def withQueryId(queryId: String): WriteBuilder = { this.queryId = queryId this } - override def mode(mode: SaveMode): WriteBuilder = { - this.mode = mode + override def truncate(): WriteBuilder = { + this.needTruncate = true this } override def buildForBatch(): BatchWrite = { - assert(mode != null) - val hadoopPath = new Path(path) val hadoopConf = SparkContext.getActive.get.hadoopConfiguration val fs = hadoopPath.getFileSystem(hadoopConf) - if (mode == SaveMode.ErrorIfExists) { - if (fs.exists(hadoopPath)) { - throw new RuntimeException("data already exists.") - } - } - if (mode == SaveMode.Ignore) { - if (fs.exists(hadoopPath)) { - return null - } - } - if (mode == SaveMode.Overwrite) { + if (needTruncate) { fs.delete(hadoopPath, true) } @@ -142,22 +130,27 @@ class SimpleWritableDataSource extends DataSourceV2 } } - class MyTable(options: DataSourceOptions) extends SimpleBatchTable with SupportsBatchWrite { - private val path = options.get("path").get() + class MyTable(options: CaseInsensitiveStringMap) + extends SimpleBatchTable with SupportsWrite { + + private val path = options.get("path") private val conf = SparkContext.getActive.get.hadoopConfiguration override def schema(): StructType = tableSchema - override def newScanBuilder(options: DataSourceOptions): ScanBuilder = { + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder(new Path(path).toUri.toString, conf) } - override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = { + override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = { new MyWriteBuilder(path) } + + override def capabilities(): util.Set[TableCapability] = + Set(BATCH_READ, BATCH_WRITE, TRUNCATE).asJava } - override def getTable(options: DataSourceOptions): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { new MyTable(options) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala new file mode 100644 index 0000000000000..42c2db2539060 --- /dev/null +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2 + +import java.util +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.sql.catalog.v2.{CatalogV2Implicits, Identifier, TableCatalog, TableChange, TestTableCatalog} +import org.apache.spark.sql.catalog.v2.expressions.Transform +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder} +import org.apache.spark.sql.sources.v2.writer.{BatchWrite, DataWriter, DataWriterFactory, SupportsTruncate, WriteBuilder, WriterCommitMessage} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +// this is currently in the spark-sql module because the read and write API is not in catalyst +// TODO(rdblue): when the v2 source API is in catalyst, merge with TestTableCatalog/InMemoryTable +class TestInMemoryTableCatalog extends TableCatalog { + import CatalogV2Implicits._ + + private val tables: util.Map[Identifier, InMemoryTable] = + new ConcurrentHashMap[Identifier, InMemoryTable]() + private var _name: Option[String] = None + + override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = { + _name = Some(name) + } + + override def name: String = _name.get + + override def listTables(namespace: Array[String]): Array[Identifier] = { + tables.keySet.asScala.filter(_.namespace.sameElements(namespace)).toArray + } + + override def loadTable(ident: Identifier): Table = { + Option(tables.get(ident)) match { + case Some(table) => + table + case _ => + throw new NoSuchTableException(ident) + } + } + + override def createTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): Table = { + + if (tables.containsKey(ident)) { + throw new TableAlreadyExistsException(ident) + } + + if (partitions.nonEmpty) { + throw new UnsupportedOperationException( + s"Catalog $name: Partitioned tables are not supported") + } + + val table = new InMemoryTable(s"$name.${ident.quoted}", schema, properties) + + tables.put(ident, table) + + table + } + + override def alterTable(ident: Identifier, changes: TableChange*): Table = { + Option(tables.get(ident)) match { + case Some(table) => + val properties = TestTableCatalog.applyPropertiesChanges(table.properties, changes) + val schema = TestTableCatalog.applySchemaChanges(table.schema, changes) + 
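// A minimal end-to-end sketch of how this catalog is exercised, mirroring the
// before{} block and CreateTable tests of DataSourceV2SQLSuite earlier in this
// patch. The table name "demo" is illustrative; imports match that suite.
import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._
import org.apache.spark.sql.catalog.v2.Identifier
import org.apache.spark.sql.types.{LongType, StringType, StructType}

spark.conf.set("spark.sql.catalog.testcat", classOf[TestInMemoryTableCatalog].getName)
spark.sql("CREATE TABLE testcat.demo (id bigint, data string) USING foo")
val demoTable = spark.catalog("testcat").asTableCatalog
  .loadTable(Identifier.of(Array(), "demo"))
assert(demoTable.schema == new StructType().add("id", LongType).add("data", StringType))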
val newTable = new InMemoryTable(table.name, schema, properties, table.data) + + tables.put(ident, newTable) + + newTable + case _ => + throw new NoSuchTableException(ident) + } + } + + override def dropTable(ident: Identifier): Boolean = Option(tables.remove(ident)).isDefined + + def clearTables(): Unit = { + tables.clear() + } +} + +/** + * A simple in-memory table. Rows are stored as a buffered group produced by each output task. + */ +private class InMemoryTable( + val name: String, + val schema: StructType, + override val properties: util.Map[String, String]) + extends Table with SupportsRead with SupportsWrite { + + def this( + name: String, + schema: StructType, + properties: util.Map[String, String], + data: Array[BufferedRows]) = { + this(name, schema, properties) + replaceData(data) + } + + def rows: Seq[InternalRow] = data.flatMap(_.rows) + + @volatile var data: Array[BufferedRows] = Array.empty + + def replaceData(buffers: Array[BufferedRows]): Unit = synchronized { + data = buffers + } + + override def capabilities: util.Set[TableCapability] = Set( + TableCapability.BATCH_READ, TableCapability.BATCH_WRITE, TableCapability.TRUNCATE).asJava + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + new ScanBuilder() { + def build(): Scan = new InMemoryBatchScan(data.map(_.asInstanceOf[InputPartition])) + } + } + + class InMemoryBatchScan(data: Array[InputPartition]) extends Scan with Batch { + override def readSchema(): StructType = schema + + override def toBatch: Batch = this + + override def planInputPartitions(): Array[InputPartition] = data + + override def createReaderFactory(): PartitionReaderFactory = BufferedRowsReaderFactory + } + + override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = { + new WriteBuilder with SupportsTruncate { + private var shouldTruncate: Boolean = false + + override def truncate(): WriteBuilder = { + shouldTruncate = true + this + } + + override def buildForBatch(): BatchWrite = { + if (shouldTruncate) TruncateAndAppend else Append + } + } + } + + private object TruncateAndAppend extends BatchWrite { + override def createBatchWriterFactory(): DataWriterFactory = { + BufferedRowsWriterFactory + } + + override def commit(messages: Array[WriterCommitMessage]): Unit = { + replaceData(messages.map(_.asInstanceOf[BufferedRows])) + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = { + } + } + + private object Append extends BatchWrite { + override def createBatchWriterFactory(): DataWriterFactory = { + BufferedRowsWriterFactory + } + + override def commit(messages: Array[WriterCommitMessage]): Unit = { + replaceData(data ++ messages.map(_.asInstanceOf[BufferedRows])) + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = { + } + } +} + +private class BufferedRows extends WriterCommitMessage with InputPartition with Serializable { + val rows = new mutable.ArrayBuffer[InternalRow]() +} + +private object BufferedRowsReaderFactory extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + new BufferedRowsReader(partition.asInstanceOf[BufferedRows]) + } +} + +private class BufferedRowsReader(partition: BufferedRows) extends PartitionReader[InternalRow] { + private var index: Int = -1 + + override def next(): Boolean = { + index += 1 + index < partition.rows.length + } + + override def get(): InternalRow = partition.rows(index) + + override def close(): Unit = {} +} + +private object 
BufferedRowsWriterFactory extends DataWriterFactory { + override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = { + new BufferWriter + } +} + +private class BufferWriter extends DataWriter[InternalRow] { + private val buffer = new BufferedRows + + override def write(row: InternalRow): Unit = buffer.rows.append(row.copy()) + + override def commit(): WriterCommitMessage = buffer + + override def abort(): Unit = {} +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/V2WriteSupportCheckSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/V2WriteSupportCheckSuite.scala new file mode 100644 index 0000000000000..1d76ee34a0e0b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/V2WriteSupportCheckSuite.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2 + +import java.util + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, NamedRelation} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LeafNode, OverwriteByExpression, OverwritePartitionsDynamic} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, V2WriteSupportCheck} +import org.apache.spark.sql.sources.v2.TableCapability._ +import org.apache.spark.sql.types.{LongType, StringType, StructType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class V2WriteSupportCheckSuite extends AnalysisTest { + + test("AppendData: check missing capabilities") { + val plan = AppendData.byName( + DataSourceV2Relation.create(CapabilityTable(), CaseInsensitiveStringMap.empty), TestRelation) + + val exc = intercept[AnalysisException]{ + V2WriteSupportCheck.apply(plan) + } + + assert(exc.getMessage.contains("does not support append in batch mode")) + } + + test("AppendData: check correct capabilities") { + val plan = AppendData.byName( + DataSourceV2Relation.create(CapabilityTable(BATCH_WRITE), CaseInsensitiveStringMap.empty), + TestRelation) + + V2WriteSupportCheck.apply(plan) + } + + test("Truncate: check missing capabilities") { + Seq(CapabilityTable(), + CapabilityTable(BATCH_WRITE), + CapabilityTable(TRUNCATE), + CapabilityTable(OVERWRITE_BY_FILTER)).foreach { table => + + val plan = OverwriteByExpression.byName( + DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation, + Literal(true)) + + val exc = intercept[AnalysisException]{ + V2WriteSupportCheck.apply(plan) + } + + assert(exc.getMessage.contains("does not support truncate in batch mode")) + } + } + + test("Truncate: check correct 
capabilities") { + Seq(CapabilityTable(BATCH_WRITE, TRUNCATE), + CapabilityTable(BATCH_WRITE, OVERWRITE_BY_FILTER)).foreach { table => + + val plan = OverwriteByExpression.byName( + DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation, + Literal(true)) + + V2WriteSupportCheck.apply(plan) + } + } + + test("OverwriteByExpression: check missing capabilities") { + Seq(CapabilityTable(), + CapabilityTable(BATCH_WRITE), + CapabilityTable(OVERWRITE_BY_FILTER)).foreach { table => + + val plan = OverwriteByExpression.byName( + DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation, + EqualTo(AttributeReference("x", LongType)(), Literal(5))) + + val exc = intercept[AnalysisException]{ + V2WriteSupportCheck.apply(plan) + } + + assert(exc.getMessage.contains( + "does not support overwrite expression (`x` = 5) in batch mode")) + } + } + + test("OverwriteByExpression: check correct capabilities") { + val table = CapabilityTable(BATCH_WRITE, OVERWRITE_BY_FILTER) + val plan = OverwriteByExpression.byName( + DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation, + EqualTo(AttributeReference("x", LongType)(), Literal(5))) + + V2WriteSupportCheck.apply(plan) + } + + test("OverwritePartitionsDynamic: check missing capabilities") { + Seq(CapabilityTable(), + CapabilityTable(BATCH_WRITE), + CapabilityTable(OVERWRITE_DYNAMIC)).foreach { table => + + val plan = OverwritePartitionsDynamic.byName( + DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation) + + val exc = intercept[AnalysisException] { + V2WriteSupportCheck.apply(plan) + } + + assert(exc.getMessage.contains("does not support dynamic overwrite in batch mode")) + } + } + + test("OverwritePartitionsDynamic: check correct capabilities") { + val table = CapabilityTable(BATCH_WRITE, OVERWRITE_DYNAMIC) + val plan = OverwritePartitionsDynamic.byName( + DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation) + + V2WriteSupportCheck.apply(plan) + } +} + +private object V2WriteSupportCheckSuite { + val schema: StructType = new StructType().add("id", LongType).add("data", StringType) +} + +private case object TestRelation extends LeafNode with NamedRelation { + override def name: String = "source_relation" + override def output: Seq[AttributeReference] = V2WriteSupportCheckSuite.schema.toAttributes +} + +private case class CapabilityTable(_capabilities: TableCapability*) extends Table { + override def name(): String = "capability_test_table" + override def schema(): StructType = V2WriteSupportCheckSuite.schema + override def capabilities(): util.Set[TableCapability] = _capabilities.toSet.asJava +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index c696204cecc2c..a0a55c08ff018 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.{AnalysisException, Dataset} import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.sources.MemorySink import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.internal.SQLConf import 
org.apache.spark.sql.streaming.OutputMode._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index ed53def556cb8..619d118e20873 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.streaming +import java.io.File import java.util.Locale import org.apache.hadoop.fs.Path @@ -454,4 +455,27 @@ class FileStreamSinkSuite extends StreamTest { } } } + + test("special characters in output path") { + withTempDir { tempDir => + val checkpointDir = new File(tempDir, "chk") + val outputDir = new File(tempDir, "output @#output") + val inputData = MemoryStream[Int] + inputData.addData(1, 2, 3) + val q = inputData.toDF() + .writeStream + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .format("parquet") + .start(outputDir.getCanonicalPath) + try { + q.processAllAvailable() + } finally { + q.stop() + } + // The "_spark_metadata" directory should be in "outputDir" + assert(outputDir.listFiles.map(_.getName).contains(FileStreamSink.metadataDir)) + val outputDf = spark.read.parquet(outputDir.getCanonicalPath).as[Int] + checkDatasetUnorderly(outputDf, 1, 2, 3) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 9235c6d7c896f..0736c6ef00eed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.FileStreamSource.{FileEntry, SeenFilesMap} +import org.apache.spark.sql.execution.streaming.sources.MemorySink import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.ExistsThrowsExceptionFileSystem._ import org.apache.spark.sql.streaming.util.StreamManualClock diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 659deb8cbb51e..f229b08a20aa0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -29,7 +29,7 @@ import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration import org.scalatest.time.SpanSugar._ -import org.apache.spark.{SparkConf, SparkContext, TaskContext} +import org.apache.spark.{SparkConf, SparkContext, TaskContext, TestUtils} import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.Range @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream +import org.apache.spark.sql.execution.streaming.sources.{ContinuousMemoryStream, MemorySink} import 
org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProvider} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -876,8 +876,8 @@ class StreamSuite extends StreamTest { query.awaitTermination() } - assert(e.getMessage.contains(providerClassName)) - assert(e.getMessage.contains("instantiated")) + TestUtils.assertExceptionMsg(e, providerClassName) + TestUtils.assertExceptionMsg(e, "instantiated") } } @@ -1083,15 +1083,15 @@ class StreamSuite extends StreamTest { test("SPARK-26379 Structured Streaming - Exception on adding current_timestamp " + " to Dataset - use v2 sink") { - testCurrentTimestampOnStreamingQuery(useV2Sink = true) + testCurrentTimestampOnStreamingQuery() } test("SPARK-26379 Structured Streaming - Exception on adding current_timestamp " + " to Dataset - use v1 sink") { - testCurrentTimestampOnStreamingQuery(useV2Sink = false) + testCurrentTimestampOnStreamingQuery() } - private def testCurrentTimestampOnStreamingQuery(useV2Sink: Boolean): Unit = { + private def testCurrentTimestampOnStreamingQuery(): Unit = { val input = MemoryStream[Int] val df = input.toDS().withColumn("cur_timestamp", lit(current_timestamp())) @@ -1109,7 +1109,7 @@ class StreamSuite extends StreamTest { var lastTimestamp = System.currentTimeMillis() val currentDate = DateTimeUtils.millisToDays(lastTimestamp) - testStream(df, useV2Sink = useV2Sink) ( + testStream(df) ( AddData(input, 1), CheckLastBatch { rows: Seq[Row] => lastTimestamp = assertBatchOutputAndUpdateLastTimestamp(rows, lastTimestamp, currentDate, 1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index da496837e7a19..fc72c940b922a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -42,8 +42,9 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch} -import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 +import org.apache.spark.sql.execution.streaming.sources.MemorySink import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2, SparkDataStream} import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.{Clock, SystemClock, Utils} @@ -86,7 +87,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } protected val defaultTrigger = Trigger.ProcessingTime(0) - protected val defaultUseV2Sink = false /** How long to wait for an active stream to catch up when checking a result. */ val streamingTimeout = 10.seconds @@ -126,7 +126,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be * the active query, and then return the source object the data was added, as well as the * offset of added data. */ - def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) + def addData(query: Option[StreamExecution]): (SparkDataStream, OffsetV2) } /** A trait that can be extended when testing a source. 
*/ @@ -137,7 +137,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be case class AddDataMemory[A](source: MemoryStreamBase[A], data: Seq[A]) extends AddData { override def toString: String = s"AddData to $source: ${data.mkString(",")}" - override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + override def addData(query: Option[StreamExecution]): (SparkDataStream, OffsetV2) = { (source, source.addData(data)) } } @@ -294,7 +294,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be /** Execute arbitrary code */ object Execute { def apply(name: String)(func: StreamExecution => Any): AssertOnQuery = - AssertOnQuery(query => { func(query); true }, "name") + AssertOnQuery(query => { func(query); true }, name) def apply(func: StreamExecution => Any): AssertOnQuery = apply("Execute")(func) } @@ -327,8 +327,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be */ def testStream( _stream: Dataset[_], - outputMode: OutputMode = OutputMode.Append, - useV2Sink: Boolean = defaultUseV2Sink)(actions: StreamAction*): Unit = synchronized { + outputMode: OutputMode = OutputMode.Append)(actions: StreamAction*): Unit = synchronized { import org.apache.spark.sql.streaming.util.StreamManualClock // `synchronized` is added to prevent the user from calling multiple `testStream`s concurrently @@ -340,8 +339,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be var pos = 0 var currentStream: StreamExecution = null var lastStream: StreamExecution = null - val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for - val sink = if (useV2Sink) new MemorySinkV2 else new MemorySink(stream.schema, outputMode) + val awaiting = new mutable.HashMap[Int, OffsetV2]() // source index -> offset to wait for + val sink = new MemorySink val resetConfValues = mutable.Map[String, Option[String]]() val defaultCheckpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath @@ -394,10 +393,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } def testState = { - val sinkDebugString = sink match { - case s: MemorySink => s.toDebugString - case s: MemorySinkV2 => s.toDebugString - } + val sinkDebugString = sink.toDebugString + s""" |== Progress == |$testActions diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 97dbb9b0360ec..3f304e9ec7788 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -21,7 +21,7 @@ import java.io.File import java.util.{Locale, TimeZone} import org.apache.commons.io.FileUtils -import org.scalatest.{Assertions, BeforeAndAfterAll} +import org.scalatest.Assertions import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.rdd.BlockRDD @@ -32,7 +32,8 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.exchange.Exchange import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.state.{StateStore, StreamingAggregationStateManager} +import org.apache.spark.sql.execution.streaming.sources.MemorySink +import 
org.apache.spark.sql.execution.streaming.state.StreamingAggregationStateManager import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index d00f2e3bf4d1a..5351d9cf7f190 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -180,7 +180,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { val listeners = (1 to 5).map(_ => new EventCollector) try { listeners.foreach(listener => spark.streams.addListener(listener)) - testStream(df, OutputMode.Append, useV2Sink = true)( + testStream(df, OutputMode.Append)( StartStream(Trigger.Continuous(1000)), StopStream, AssertOnQuery { query => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index dc22e31678fa3..ec0be40528a45 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -17,23 +17,26 @@ package org.apache.spark.sql.streaming +import java.io.File import java.util.concurrent.CountDownLatch import scala.collection.mutable +import org.apache.commons.io.FileUtils import org.apache.commons.lang3.RandomStringUtils +import org.apache.hadoop.fs.Path import org.scalactic.TolerantNumerics import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.mockito.MockitoSugar -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, TestUtils} import org.apache.spark.internal.Logging import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} import org.apache.spark.sql.catalyst.expressions.{Literal, Rand, Randn, Shuffle, Uuid} import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter +import org.apache.spark.sql.execution.streaming.sources.{MemorySink, TestForeachWriter} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2.reader.InputPartition @@ -495,7 +498,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi test("input row calculation with same V2 source used twice in self-union") { val streamInput = MemoryStream[Int] - testStream(streamInput.toDF().union(streamInput.toDF()), useV2Sink = true)( + testStream(streamInput.toDF().union(streamInput.toDF()))( AddData(streamInput, 1, 2, 3), CheckAnswer(1, 1, 2, 2, 3, 3), AssertOnQuery { q => @@ -516,7 +519,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi // relation, which breaks exchange reuse, as the optimizer will remove Project from one side. // Here we manually add a useful Project, to trigger exchange reuse. 
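    // ('value + 0 is that useful projection: a bare select('value) would be pruned from one side
    // of the self-join below, leaving the two branches with different plans and nothing for the
    // exchange-reuse rule to match.)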
val streamDF = memoryStream.toDF().select('value + 0 as "v") - testStream(streamDF.join(streamDF, "v"), useV2Sink = true)( + testStream(streamDF.join(streamDF, "v"))( AddData(memoryStream, 1, 2, 3), CheckAnswer(1, 2, 3), check @@ -553,7 +556,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi val streamInput1 = MemoryStream[Int] val streamInput2 = MemoryStream[Int] - testStream(streamInput1.toDF().union(streamInput2.toDF()), useV2Sink = true)( + testStream(streamInput1.toDF().union(streamInput2.toDF()))( AddData(streamInput1, 1, 2, 3), CheckLastBatch(1, 2, 3), AssertOnQuery { q => @@ -584,7 +587,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi val streamInput = MemoryStream[Int] val staticInputDF = spark.createDataFrame(Seq(1 -> "1", 2 -> "2")).toDF("value", "anotherValue") - testStream(streamInput.toDF().join(staticInputDF, "value"), useV2Sink = true)( + testStream(streamInput.toDF().join(staticInputDF, "value"))( AddData(streamInput, 1, 2, 3), AssertOnQuery { q => q.processAllAvailable() @@ -606,7 +609,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi val streamInput2 = MemoryStream[Int] val staticInputDF2 = staticInputDF.union(staticInputDF).cache() - testStream(streamInput2.toDF().join(staticInputDF2, "value"), useV2Sink = true)( + testStream(streamInput2.toDF().join(staticInputDF2, "value"))( AddData(streamInput2, 1, 2, 3), AssertOnQuery { q => q.processAllAvailable() @@ -714,8 +717,8 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi q3.processAllAvailable() } assert(e.getCause.isInstanceOf[SparkException]) - assert(e.getCause.getCause.isInstanceOf[IllegalStateException]) - assert(e.getMessage.contains("StreamingQuery cannot be used in executors")) + assert(e.getCause.getCause.getCause.isInstanceOf[IllegalStateException]) + TestUtils.assertExceptionMsg(e, "StreamingQuery cannot be used in executors") } finally { q1.stop() q2.stop() @@ -909,12 +912,195 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AssertOnQuery(_.logicalPlan.toJSON.contains("StreamingDataSourceV2Relation")) ) - testStream(df, useV2Sink = true)( + testStream(df)( StartStream(trigger = Trigger.Continuous(100)), AssertOnQuery(_.logicalPlan.toJSON.contains("StreamingDataSourceV2Relation")) ) } + test("special characters in checkpoint path") { + withTempDir { tempDir => + val checkpointDir = new File(tempDir, "chk @#chk") + val inputData = MemoryStream[Int] + inputData.addData(1) + val q = inputData.toDF() + .writeStream + .format("noop") + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .start() + try { + q.processAllAvailable() + assert(checkpointDir.listFiles().toList.nonEmpty) + } finally { + q.stop() + } + } + } + + /** + * Copy the checkpoint generated by Spark 2.4.0 from test resource to `dir` to set up a legacy + * streaming checkpoint. + */ + private def setUp2dot4dot0Checkpoint(dir: File): Unit = { + val input = getClass.getResource("/structured-streaming/escaped-path-2.4.0") + assert(input != null, "cannot find test resource '/structured-streaming/escaped-path-2.4.0'") + val inputDir = new File(input.toURI) + + // Copy test files to tempDir so that we won't modify the original data. 
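+    // (The copied checkpoint was written with Spark 2.4's old path-escaping behaviour, so the
+    // directory names inside it are the percent-encoded "legacy" forms that the tests below
+    // reconstruct and assert on.)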
+ FileUtils.copyDirectory(inputDir, dir) + + // Spark 2.4 and earlier escaped the _spark_metadata path once + val legacySparkMetadataDir = new File( + dir, + new Path("output %@#output/_spark_metadata").toUri.toString) + + // Migrate from legacy _spark_metadata directory to the new _spark_metadata directory. + // Ideally we should copy "_spark_metadata" directly like what the user is supposed to do to + // migrate to new version. However, in our test, "tempDir" will be different in each run and + // we need to fix the absolute path in the metadata to match "tempDir". + val sparkMetadata = FileUtils.readFileToString(new File(legacySparkMetadataDir, "0"), "UTF-8") + FileUtils.write( + new File(legacySparkMetadataDir, "0"), + sparkMetadata.replaceAll("TEMPDIR", dir.getCanonicalPath), + "UTF-8") + } + + test("detect escaped path and report the migration guide") { + // Assert that the error message contains the migration conf, path and the legacy path. + def assertMigrationError(errorMessage: String, path: File, legacyPath: File): Unit = { + Seq(SQLConf.STREAMING_CHECKPOINT_ESCAPED_PATH_CHECK_ENABLED.key, + path.getCanonicalPath, + legacyPath.getCanonicalPath).foreach { msg => + assert(errorMessage.contains(msg)) + } + } + + withTempDir { tempDir => + setUp2dot4dot0Checkpoint(tempDir) + + // Here are the paths we will use to create the query + val outputDir = new File(tempDir, "output %@#output") + val checkpointDir = new File(tempDir, "chk %@#chk") + val sparkMetadataDir = new File(tempDir, "output %@#output/_spark_metadata") + + // The escaped paths used by Spark 2.4 and earlier. + // Spark 2.4 and earlier escaped the checkpoint path three times + val legacyCheckpointDir = new File( + tempDir, + new Path(new Path(new Path("chk %@#chk").toUri.toString).toUri.toString).toUri.toString) + // Spark 2.4 and earlier escaped the _spark_metadata path once + val legacySparkMetadataDir = new File( + tempDir, + new Path("output %@#output/_spark_metadata").toUri.toString) + + // Reading a file sink output in a batch query should detect the legacy _spark_metadata + // directory and throw an error + val e = intercept[SparkException] { + spark.read.load(outputDir.getCanonicalPath).as[Int] + } + assertMigrationError(e.getMessage, sparkMetadataDir, legacySparkMetadataDir) + + // Restarting the streaming query should detect the legacy _spark_metadata directory and throw + // an error + val inputData = MemoryStream[Int] + val e2 = intercept[SparkException] { + inputData.toDF() + .writeStream + .format("parquet") + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .start(outputDir.getCanonicalPath) + } + assertMigrationError(e2.getMessage, sparkMetadataDir, legacySparkMetadataDir) + + // Move "_spark_metadata" to fix the file sink and test the checkpoint path. + FileUtils.moveDirectory(legacySparkMetadataDir, sparkMetadataDir) + + // Restarting the streaming query should detect the legacy checkpoint path and throw an error + val e3 = intercept[SparkException] { + inputData.toDF() + .writeStream + .format("parquet") + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .start(outputDir.getCanonicalPath) + } + assertMigrationError(e3.getMessage, checkpointDir, legacyCheckpointDir) + + // Fix the checkpoint path and verify that the user can migrate the issue by moving files. 
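+      // (legacyCheckpointDir is the triple-escaped name built above: each `new Path(str).toUri.toString`
+      // round trip percent-encodes the name once more -- e.g. "chk %@#chk" becomes roughly
+      // "chk%20%25@%23chk" after one pass -- and Spark 2.4 applied it three times to checkpoint
+      // paths. Moving that directory onto checkpointDir is the migration the error message above
+      // asks the user to perform.)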
+ FileUtils.moveDirectory(legacyCheckpointDir, checkpointDir) + + val q = inputData.toDF() + .writeStream + .format("parquet") + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .start(outputDir.getCanonicalPath) + try { + q.processAllAvailable() + // Check the query id to make sure it did use checkpoint + assert(q.id.toString == "09be7fb3-49d8-48a6-840d-e9c2ad92a898") + + // Verify that the batch query can read "_spark_metadata" correctly after migration. + val df = spark.read.load(outputDir.getCanonicalPath) + assert(df.queryExecution.executedPlan.toString contains "MetadataLogFileIndex") + checkDatasetUnorderly(df.as[Int], 1, 2, 3) + } finally { + q.stop() + } + } + } + + test("ignore the escaped path check when the flag is off") { + withTempDir { tempDir => + setUp2dot4dot0Checkpoint(tempDir) + val outputDir = new File(tempDir, "output %@#output") + val checkpointDir = new File(tempDir, "chk %@#chk") + + withSQLConf(SQLConf.STREAMING_CHECKPOINT_ESCAPED_PATH_CHECK_ENABLED.key -> "false") { + // Verify that the batch query ignores the legacy "_spark_metadata" + val df = spark.read.load(outputDir.getCanonicalPath) + assert(!(df.queryExecution.executedPlan.toString contains "MetadataLogFileIndex")) + checkDatasetUnorderly(df.as[Int], 1, 2, 3) + + val inputData = MemoryStream[Int] + val q = inputData.toDF() + .writeStream + .format("parquet") + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .start(outputDir.getCanonicalPath) + try { + q.processAllAvailable() + // Check the query id to make sure it ignores the legacy checkpoint + assert(q.id.toString != "09be7fb3-49d8-48a6-840d-e9c2ad92a898") + } finally { + q.stop() + } + } + } + } + + test("containsSpecialCharsInPath") { + Seq("foo/b ar", + "/foo/b ar", + "file:/foo/b ar", + "file://foo/b ar", + "file:///foo/b ar", + "file://foo:bar@bar/foo/b ar").foreach { p => + assert(StreamExecution.containsSpecialCharsInPath(new Path(p)), s"failed to check $p") + } + Seq("foo/bar", + "/foo/bar", + "file:/foo/bar", + "file://foo/bar", + "file:///foo/bar", + "file://foo:bar@bar/foo/bar", + // Special chars not in a path should not be considered as such urls won't hit the escaped + // path issue. 
+ "file://foo:b ar@bar/foo/bar", + "file://foo:bar@b ar/foo/bar", + "file://f oo:bar@bar/foo/bar").foreach { p => + assert(!StreamExecution.containsSpecialCharsInPath(new Path(p)), s"failed to check $p") + } + } + /** Create a streaming DF that only execute one batch in which it returns the given static DF */ private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame = { require(!triggerDF.isStreaming) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueryStatusAndProgressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueryStatusAndProgressSuite.scala index 10bea7f090571..59d6ac0af52a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueryStatusAndProgressSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueryStatusAndProgressSuite.scala @@ -34,7 +34,7 @@ class ContinuousQueryStatusAndProgressSuite extends ContinuousSuiteBase { } val trigger = Trigger.Continuous(100) - testStream(input.toDF(), useV2Sink = true)( + testStream(input.toDF())( StartStream(trigger), Execute(assertStatus), AddData(input, 0, 1, 2), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala index d3d210c02e90d..bad22590807a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReader, ContinuousStream, PartitionOffset} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.types.{DataType, IntegerType, StructType} @@ -43,7 +43,7 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { override def beforeEach(): Unit = { super.beforeEach() epochEndpoint = EpochCoordinatorRef.create( - mock[StreamingWriteSupport], + mock[StreamingWrite], mock[ContinuousStream], mock[ContinuousExecution], coordinatorId, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 344a8aa55f0c5..9840c7f066780 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf.CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE import org.apache.spark.sql.streaming.{StreamTest, Trigger} import org.apache.spark.sql.test.TestSparkSession @@ -56,7 +57,6 @@ class ContinuousSuiteBase extends 
StreamTest { protected val longContinuousTrigger = Trigger.Continuous("1 hour") override protected val defaultTrigger = Trigger.Continuous(100) - override protected val defaultUseV2Sink = true } class ContinuousSuite extends ContinuousSuiteBase { @@ -238,7 +238,7 @@ class ContinuousStressSuite extends ContinuousSuiteBase { .load() .select('value) - testStream(df, useV2Sink = true)( + testStream(df)( StartStream(longContinuousTrigger), AwaitEpoch(0), Execute(waitForRateSourceTriggers(_, 10)), @@ -256,7 +256,7 @@ class ContinuousStressSuite extends ContinuousSuiteBase { .load() .select('value) - testStream(df, useV2Sink = true)( + testStream(df)( StartStream(Trigger.Continuous(2012)), AwaitEpoch(0), Execute(waitForRateSourceTriggers(_, 10)), @@ -273,7 +273,7 @@ class ContinuousStressSuite extends ContinuousSuiteBase { .load() .select('value) - testStream(df, useV2Sink = true)( + testStream(df)( StartStream(Trigger.Continuous(1012)), AwaitEpoch(2), StopStream, @@ -343,3 +343,33 @@ class ContinuousMetaSuite extends ContinuousSuiteBase { } } } + +class ContinuousEpochBacklogSuite extends ContinuousSuiteBase { + import testImplicits._ + + override protected def createSparkSession = new TestSparkSession( + new SparkContext( + "local[1]", + "continuous-stream-test-sql-context", + sparkConf.set("spark.sql.testkey", "true"))) + + // This test forces the backlog to overflow by not standing up enough executors for the query + // to make progress. + test("epoch backlog overflow") { + withSQLConf((CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE.key, "10")) { + val df = spark.readStream + .format("rate") + .option("numPartitions", "2") + .option("rowsPerSecond", "500") + .load() + .select('value) + + testStream(df)( + StartStream(Trigger.Continuous(1)), + ExpectFailure[IllegalStateException] { e => + e.getMessage.contains("queue has exceeded its maximum") + } + ) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala index a0b56ec17f0be..e3498db4194e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.streaming.continuous +import org.mockito.{ArgumentCaptor, InOrder} import org.mockito.ArgumentMatchers.{any, eq => eqTo} -import org.mockito.InOrder -import org.mockito.Mockito.{inOrder, never, verify} +import org.mockito.Mockito._ import org.scalatest.BeforeAndAfterEach import org.scalatest.mockito.MockitoSugar @@ -27,9 +27,10 @@ import org.apache.spark._ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.LocalSparkSession import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.internal.SQLConf.CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite import org.apache.spark.sql.test.TestSparkSession class EpochCoordinatorSuite @@ -40,17 +41,22 @@ class EpochCoordinatorSuite private var epochCoordinator: RpcEndpointRef = _ - private var writeSupport: StreamingWriteSupport = _ + private var writeSupport: 
StreamingWrite = _ private var query: ContinuousExecution = _ private var orderVerifier: InOrder = _ + private val epochBacklogQueueSize = 10 override def beforeEach(): Unit = { val stream = mock[ContinuousStream] - writeSupport = mock[StreamingWriteSupport] + writeSupport = mock[StreamingWrite] query = mock[ContinuousExecution] orderVerifier = inOrder(writeSupport, query) - spark = new TestSparkSession() + spark = new TestSparkSession( + new SparkContext( + "local[2]", "test-sql-context", + new SparkConf().set("spark.sql.testkey", "true") + .set(CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE, epochBacklogQueueSize))) epochCoordinator = EpochCoordinatorRef.create(writeSupport, stream, query, "test", 1, spark, SparkEnv.get) @@ -186,6 +192,66 @@ class EpochCoordinatorSuite verifyCommitsInOrderOf(List(1, 2, 3, 4, 5)) } + test("several epochs, max epoch backlog reached by partitionOffsets") { + setWriterPartitions(1) + setReaderPartitions(1) + + reportPartitionOffset(0, 1) + // Commit messages not arriving + for (i <- 2 to epochBacklogQueueSize + 1) { + reportPartitionOffset(0, i) + } + + makeSynchronousCall() + + for (i <- 1 to epochBacklogQueueSize + 1) { + verifyNoCommitFor(i) + } + verifyStoppedWithException("Size of the partition offset queue has exceeded its maximum") + } + + test("several epochs, max epoch backlog reached by partitionCommits") { + setWriterPartitions(1) + setReaderPartitions(1) + + commitPartitionEpoch(0, 1) + // Offset messages not arriving + for (i <- 2 to epochBacklogQueueSize + 1) { + commitPartitionEpoch(0, i) + } + + makeSynchronousCall() + + for (i <- 1 to epochBacklogQueueSize + 1) { + verifyNoCommitFor(i) + } + verifyStoppedWithException("Size of the partition commit queue has exceeded its maximum") + } + + test("several epochs, max epoch backlog reached by epochsWaitingToBeCommitted") { + setWriterPartitions(2) + setReaderPartitions(2) + + commitPartitionEpoch(0, 1) + reportPartitionOffset(0, 1) + + // For partition 2 epoch 1 messages never arriving + // +2 because the first epoch not yet arrived + for (i <- 2 to epochBacklogQueueSize + 2) { + commitPartitionEpoch(0, i) + reportPartitionOffset(0, i) + commitPartitionEpoch(1, i) + reportPartitionOffset(1, i) + } + + makeSynchronousCall() + + for (i <- 1 to epochBacklogQueueSize + 2) { + verifyNoCommitFor(i) + } + verifyStoppedWithException("Size of the epoch queue has exceeded its maximum") + } + private def setWriterPartitions(numPartitions: Int): Unit = { epochCoordinator.askSync[Unit](SetWriterPartitions(numPartitions)) } @@ -221,4 +287,13 @@ class EpochCoordinatorSuite private def verifyCommitsInOrderOf(epochs: Seq[Long]): Unit = { epochs.foreach(verifyCommit) } + + private def verifyStoppedWithException(msg: String): Unit = { + val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable]); + verify(query, atLeastOnce()).stopInNewThread(exceptionCaptor.capture()) + + import scala.collection.JavaConverters._ + val throwable = exceptionCaptor.getAllValues.asScala.find(_.getMessage === msg) + assert(throwable != null, "Stream stopped with an exception but expected message is missing") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index 62f166602941c..7b2c1a56e8baa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -17,6 +17,11 @@ package org.apache.spark.sql.streaming.sources +import java.util +import java.util.Collections + +import scala.collection.JavaConverters._ + import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{RateStreamOffset, Sink, StreamingQueryWrapper} @@ -24,11 +29,14 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.TableCapability._ import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming._ -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.{WriteBuilder, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery, StreamTest, Trigger} import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.Utils class FakeDataStream extends MicroBatchStream with ContinuousStream { @@ -59,26 +67,27 @@ class FakeScanBuilder extends ScanBuilder with Scan { override def toContinuousStream(checkpointLocation: String): ContinuousStream = new FakeDataStream } -trait FakeMicroBatchReadTable extends Table with SupportsMicroBatchRead { - override def name(): String = "fake" - override def schema(): StructType = StructType(Seq()) - override def newScanBuilder(options: DataSourceOptions): ScanBuilder = new FakeScanBuilder +class FakeWriteBuilder extends WriteBuilder with StreamingWrite { + override def buildForStreaming(): StreamingWrite = this + override def createStreamingWriterFactory(): StreamingDataWriterFactory = { + throw new IllegalStateException("fake sink - cannot actually write") + } + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { + throw new IllegalStateException("fake sink - cannot actually write") + } + override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { + throw new IllegalStateException("fake sink - cannot actually write") + } } -trait FakeContinuousReadTable extends Table with SupportsContinuousRead { +trait FakeStreamingWriteTable extends Table with SupportsWrite { override def name(): String = "fake" override def schema(): StructType = StructType(Seq()) - override def newScanBuilder(options: DataSourceOptions): ScanBuilder = new FakeScanBuilder -} - -trait FakeStreamingWriteSupportProvider extends StreamingWriteSupportProvider { - override def createStreamingWriteSupport( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceOptions): StreamingWriteSupport = { - LastWriteOptions.options = options - throw new IllegalStateException("fake sink - cannot actually write") + override def capabilities(): util.Set[TableCapability] = { + Set(STREAMING_WRITE).asJava + } + override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = { + new FakeWriteBuilder } } @@ -90,9 +99,18 @@ class FakeReadMicroBatchOnly override def keyPrefix: String = shortName() - override def getTable(options: DataSourceOptions): Table = { + override def 
getTable(options: CaseInsensitiveStringMap): Table = { LastReadOptions.options = options - new FakeMicroBatchReadTable {} + new Table with SupportsRead { + override def name(): String = "fake" + override def schema(): StructType = StructType(Seq()) + override def capabilities(): util.Set[TableCapability] = { + Set(MICRO_BATCH_READ).asJava + } + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + new FakeScanBuilder + } + } } } @@ -104,45 +122,78 @@ class FakeReadContinuousOnly override def keyPrefix: String = shortName() - override def getTable(options: DataSourceOptions): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { LastReadOptions.options = options - new FakeContinuousReadTable {} + new Table with SupportsRead { + override def name(): String = "fake" + override def schema(): StructType = StructType(Seq()) + override def capabilities(): util.Set[TableCapability] = { + Set(CONTINUOUS_READ).asJava + } + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + new FakeScanBuilder + } + } } } class FakeReadBothModes extends DataSourceRegister with TableProvider { override def shortName(): String = "fake-read-microbatch-continuous" - override def getTable(options: DataSourceOptions): Table = { - new Table with FakeMicroBatchReadTable with FakeContinuousReadTable {} + override def getTable(options: CaseInsensitiveStringMap): Table = { + new Table with SupportsRead { + override def name(): String = "fake" + override def schema(): StructType = StructType(Seq()) + override def capabilities(): util.Set[TableCapability] = { + Set(MICRO_BATCH_READ, CONTINUOUS_READ).asJava + } + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + new FakeScanBuilder + } + } } } class FakeReadNeitherMode extends DataSourceRegister with TableProvider { override def shortName(): String = "fake-read-neither-mode" - override def getTable(options: DataSourceOptions): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { new Table { override def name(): String = "fake" override def schema(): StructType = StructType(Nil) + override def capabilities(): util.Set[TableCapability] = Collections.emptySet() } } } -class FakeWriteSupportProvider +class FakeWriteOnly extends DataSourceRegister - with FakeStreamingWriteSupportProvider + with TableProvider with SessionConfigSupport { override def shortName(): String = "fake-write-microbatch-continuous" override def keyPrefix: String = shortName() + + override def getTable(options: CaseInsensitiveStringMap): Table = { + LastWriteOptions.options = options + new Table with FakeStreamingWriteTable { + override def name(): String = "fake" + override def schema(): StructType = StructType(Nil) + } + } } -class FakeNoWrite extends DataSourceRegister { +class FakeNoWrite extends DataSourceRegister with TableProvider { override def shortName(): String = "fake-write-neither-mode" + override def getTable(options: CaseInsensitiveStringMap): Table = { + new Table { + override def name(): String = "fake" + override def schema(): StructType = StructType(Nil) + override def capabilities(): util.Set[TableCapability] = Collections.emptySet() + } + } } - case class FakeWriteV1FallbackException() extends Exception class FakeSink extends Sink { @@ -150,21 +201,28 @@ class FakeSink extends Sink { } class FakeWriteSupportProviderV1Fallback extends DataSourceRegister - with FakeStreamingWriteSupportProvider with StreamSinkProvider { + with TableProvider with 
StreamSinkProvider { override def createSink( - sqlContext: SQLContext, - parameters: Map[String, String], - partitionColumns: Seq[String], - outputMode: OutputMode): Sink = { + sqlContext: SQLContext, + parameters: Map[String, String], + partitionColumns: Seq[String], + outputMode: OutputMode): Sink = { new FakeSink() } override def shortName(): String = "fake-write-v1-fallback" + + override def getTable(options: CaseInsensitiveStringMap): Table = { + new Table with FakeStreamingWriteTable { + override def name(): String = "fake" + override def schema(): StructType = StructType(Nil) + } + } } object LastReadOptions { - var options: DataSourceOptions = _ + var options: CaseInsensitiveStringMap = _ def clear(): Unit = { options = null @@ -172,7 +230,7 @@ object LastReadOptions { } object LastWriteOptions { - var options: DataSourceOptions = _ + var options: CaseInsensitiveStringMap = _ def clear(): Unit = { options = null @@ -260,7 +318,7 @@ class StreamingDataSourceV2Suite extends StreamTest { testPositiveCaseWithQuery( "fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once()) { v2Query => assert(v2Query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sink - .isInstanceOf[FakeWriteSupportProviderV1Fallback]) + .isInstanceOf[Table]) } // Ensure we create a V1 sink with the config. Note the config is a comma separated @@ -289,8 +347,8 @@ class StreamingDataSourceV2Suite extends StreamTest { testPositiveCaseWithQuery(readSource, writeSource, trigger) { _ => eventually(timeout(streamingTimeout)) { // Write options should not be set. - assert(LastWriteOptions.options.getBoolean(readOptionName, false) == false) - assert(LastReadOptions.options.getBoolean(readOptionName, false) == true) + assert(!LastWriteOptions.options.containsKey(readOptionName)) + assert(LastReadOptions.options.getBoolean(readOptionName, false)) } } } @@ -300,8 +358,8 @@ class StreamingDataSourceV2Suite extends StreamTest { testPositiveCaseWithQuery(readSource, writeSource, trigger) { _ => eventually(timeout(streamingTimeout)) { // Read options should not be set. - assert(LastReadOptions.options.getBoolean(writeOptionName, false) == false) - assert(LastWriteOptions.options.getBoolean(writeOptionName, false) == true) + assert(!LastReadOptions.options.containsKey(writeOptionName)) + assert(LastWriteOptions.options.getBoolean(writeOptionName, false)) } } } @@ -319,44 +377,43 @@ class StreamingDataSourceV2Suite extends StreamTest { for ((read, write, trigger) <- cases) { testQuietly(s"stream with read format $read, write format $write, trigger $trigger") { - val table = DataSource.lookupDataSource(read, spark.sqlContext.conf).getConstructor() - .newInstance().asInstanceOf[TableProvider].getTable(DataSourceOptions.empty()) - val writeSource = DataSource.lookupDataSource(write, spark.sqlContext.conf). - getConstructor().newInstance() - - (table, writeSource, trigger) match { - // Valid microbatch queries. - case (_: SupportsMicroBatchRead, _: StreamingWriteSupportProvider, t) - if !t.isInstanceOf[ContinuousTrigger] => - testPositiveCase(read, write, trigger) - - // Valid continuous queries. 
- case (_: SupportsContinuousRead, _: StreamingWriteSupportProvider, - _: ContinuousTrigger) => - testPositiveCase(read, write, trigger) + val sourceTable = DataSource.lookupDataSource(read, spark.sqlContext.conf).getConstructor() + .newInstance().asInstanceOf[TableProvider].getTable(CaseInsensitiveStringMap.empty()) + + val sinkTable = DataSource.lookupDataSource(write, spark.sqlContext.conf).getConstructor() + .newInstance().asInstanceOf[TableProvider].getTable(CaseInsensitiveStringMap.empty()) + import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ + trigger match { // Invalid - can't read at all - case (r, _, _) if !r.isInstanceOf[SupportsMicroBatchRead] && - !r.isInstanceOf[SupportsContinuousRead] => + case _ if !sourceTable.supportsAny(MICRO_BATCH_READ, CONTINUOUS_READ) => testNegativeCase(read, write, trigger, s"Data source $read does not support streamed reading") // Invalid - can't write - case (_, w, _) if !w.isInstanceOf[StreamingWriteSupportProvider] => + case _ if !sinkTable.supports(STREAMING_WRITE) => testNegativeCase(read, write, trigger, s"Data source $write does not support streamed writing") - // Invalid - trigger is continuous but reader is not - case (r, _: StreamingWriteSupportProvider, _: ContinuousTrigger) - if !r.isInstanceOf[SupportsContinuousRead] => - testNegativeCase(read, write, trigger, - s"Data source $read does not support continuous processing") + case _: ContinuousTrigger => + if (sourceTable.supports(CONTINUOUS_READ)) { + // Valid continuous queries. + testPositiveCase(read, write, trigger) + } else { + // Invalid - trigger is continuous but reader is not + testNegativeCase( + read, write, trigger, s"Data source $read does not support continuous processing") + } - // Invalid - trigger is microbatch but reader is not - case (r, _, t) if !r.isInstanceOf[SupportsMicroBatchRead] && - !t.isInstanceOf[ContinuousTrigger] => - testPostCreationNegativeCase(read, write, trigger, - s"Data source $read does not support microbatch processing") + case microBatchTrigger => + if (sourceTable.supports(MICRO_BATCH_READ)) { + // Valid microbatch queries.
+ testPositiveCase(read, write, trigger) + } else { + // Invalid - trigger is microbatch but reader is not + testPostCreationNegativeCase(read, write, trigger, + s"Data source $read does not support microbatch processing") + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index 74ea0bfacba54..99dc0769a3d69 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -359,7 +359,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { test("source metadataPath") { LastOptions.clear() - val checkpointLocationURI = new Path(newMetadataDir).toUri + val checkpointLocation = new Path(newMetadataDir) val df1 = spark.readStream .format("org.apache.spark.sql.streaming.test") @@ -371,7 +371,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { val q = df1.union(df2).writeStream .format("org.apache.spark.sql.streaming.test") - .option("checkpointLocation", checkpointLocationURI.toString) + .option("checkpointLocation", checkpointLocation.toString) .trigger(ProcessingTime(10.seconds)) .start() q.processAllAvailable() @@ -379,14 +379,14 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { verify(LastOptions.mockStreamSourceProvider).createSource( any(), - meq(s"${makeQualifiedPath(checkpointLocationURI.toString)}/sources/0"), + meq(s"${new Path(makeQualifiedPath(checkpointLocation.toString)).toString}/sources/0"), meq(None), meq("org.apache.spark.sql.streaming.test"), meq(Map.empty)) verify(LastOptions.mockStreamSourceProvider).createSource( any(), - meq(s"${makeQualifiedPath(checkpointLocationURI.toString)}/sources/1"), + meq(s"${new Path(makeQualifiedPath(checkpointLocation.toString)).toString}/sources/1"), meq(None), meq("org.apache.spark.sql.streaming.test"), meq(Map.empty)) @@ -614,6 +614,21 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { } } + test("configured checkpoint dir should not be deleted if a query is stopped without errors and" + + " force temp checkpoint deletion enabled") { + import testImplicits._ + withTempDir { checkpointPath => + withSQLConf(SQLConf.CHECKPOINT_LOCATION.key -> checkpointPath.getAbsolutePath, + SQLConf.FORCE_DELETE_TEMP_CHECKPOINT_LOCATION.key -> "true") { + val ds = MemoryStream[Int].toDS + val query = ds.writeStream.format("console").start() + assert(checkpointPath.exists()) + query.stop() + assert(checkpointPath.exists()) + } + } + } + test("temp checkpoint dir should be deleted if a query is stopped without errors") { import testImplicits._ val query = MemoryStream[Int].toDS.writeStream.format("console").start() @@ -627,6 +642,17 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { } testQuietly("temp checkpoint dir should not be deleted if a query is stopped with an error") { + testTempCheckpointWithFailedQuery(false) + } + + testQuietly("temp checkpoint should be deleted if a query is stopped with an error and force" + + " temp checkpoint deletion enabled") { + withSQLConf(SQLConf.FORCE_DELETE_TEMP_CHECKPOINT_LOCATION.key -> "true") { + testTempCheckpointWithFailedQuery(true) + } + } + + private def testTempCheckpointWithFailedQuery(checkpointMustBeDeleted: Boolean): Unit = { import testImplicits._ val input = MemoryStream[Int] 
     val query = input.toDS.map(_ / 0).writeStream.format("console").start()
@@ -638,7 +664,11 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter {
     intercept[StreamingQueryException] {
       query.awaitTermination()
     }
-    assert(fs.exists(checkpointDir))
+    if (!checkpointMustBeDeleted) {
+      assert(fs.exists(checkpointDir))
+    } else {
+      assert(!fs.exists(checkpointDir))
+    }
   }
 
   test("SPARK-20431: Specify a schema by using a DDL-formatted string") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
index e45ab19aadbfa..a388de1970f14 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
@@ -38,10 +38,15 @@ import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression}
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.datasources.DataSourceUtils
+import org.apache.spark.sql.execution.datasources.noop.NoopDataSource
 import org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.QueryExecutionListener
 import org.apache.spark.util.Utils
@@ -220,15 +225,75 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
   }
 
   test("save mode") {
-    val df = spark.read
+    spark.range(10).write
       .format("org.apache.spark.sql.test")
-      .load()
+      .mode(SaveMode.ErrorIfExists)
+      .save()
+    assert(LastOptions.saveMode === SaveMode.ErrorIfExists)
 
-    df.write
+    spark.range(10).write
+      .format("org.apache.spark.sql.test")
+      .mode(SaveMode.Append)
+      .save()
+    assert(LastOptions.saveMode === SaveMode.Append)
+
+    // By default the save mode is `ErrorIfExists` for data source v1.
+    spark.range(10).write
       .format("org.apache.spark.sql.test")
-      .mode(SaveMode.ErrorIfExists)
       .save()
     assert(LastOptions.saveMode === SaveMode.ErrorIfExists)
+
+    spark.range(10).write
+      .format("org.apache.spark.sql.test")
+      .mode("default")
+      .save()
+    assert(LastOptions.saveMode === SaveMode.ErrorIfExists)
+  }
+
+  test("save mode for data source v2") {
+    var plan: LogicalPlan = null
+    val listener = new QueryExecutionListener {
+      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+        plan = qe.analyzed
+
+      }
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+    }
+
+    spark.listenerManager.register(listener)
+    try {
+      // append mode creates `AppendData`
+      spark.range(10).write
+        .format(classOf[NoopDataSource].getName)
+        .mode(SaveMode.Append)
+        .save()
+      sparkContext.listenerBus.waitUntilEmpty(1000)
+      assert(plan.isInstanceOf[AppendData])
+
+      // overwrite mode creates `OverwriteByExpression`
+      spark.range(10).write
+        .format(classOf[NoopDataSource].getName)
+        .mode(SaveMode.Overwrite)
+        .save()
+      sparkContext.listenerBus.waitUntilEmpty(1000)
+      assert(plan.isInstanceOf[OverwriteByExpression])
+
+      // By default the save mode is `ErrorIfExists` for data source v2.
+      spark.range(10).write
+        .format(classOf[NoopDataSource].getName)
+        .save()
+      sparkContext.listenerBus.waitUntilEmpty(1000)
+      assert(plan.isInstanceOf[AppendData])
+
+      spark.range(10).write
+        .format(classOf[NoopDataSource].getName)
+        .mode("default")
+        .save()
+      sparkContext.listenerBus.waitUntilEmpty(1000)
+      assert(plan.isInstanceOf[AppendData])
+    } finally {
+      spark.listenerManager.unregister(listener)
+    }
   }
 
   test("test path option in load") {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
index 132b0e4db0d71..84e5fae79bf16 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.SparkPlanner
 import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.datasources.v2.{V2StreamingScanSupportCheck, V2WriteSupportCheck}
 import org.apache.spark.sql.hive.client.HiveClient
 import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLoader, SessionState}
 
@@ -72,6 +73,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session
         new FindDataSourceTable(session) +:
         new ResolveSQLOnFile(session) +:
         new FallbackOrcDataSourceV2(session) +:
+        DataSourceResolution(conf, session.catalog(_)) +:
         customResolutionRules
 
     override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =
@@ -86,6 +88,8 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session
     override val extendedCheckRules: Seq[LogicalPlan => Unit] =
       PreWriteCheck +:
       PreReadCheck +:
+      V2WriteSupportCheck +:
+      V2StreamingScanSupportCheck +:
       customCheckRules
   }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
index 66426824573c6..a4587abbf389d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
@@ -172,7 +172,7 @@ abstract class HiveComparisonTest
       // and does not return it as a query answer.
       case _: SetCommand => Seq("0")
       case _: ExplainCommand => answer
-      case _: DescribeTableCommand | ShowColumnsCommand(_, _) =>
+      case _: DescribeCommandBase | ShowColumnsCommand(_, _) =>
         // Filter out non-deterministic lines and lines which do not have actual results but
         // can introduce problems because of the way Hive formats these lines.
         // Then, remove empty lines. Do not sort the results.
@@ -375,7 +375,7 @@ abstract class HiveComparisonTest
         if ((!hiveQuery.logical.isInstanceOf[ExplainCommand]) &&
             (!hiveQuery.logical.isInstanceOf[ShowFunctionsCommand]) &&
             (!hiveQuery.logical.isInstanceOf[DescribeFunctionCommand]) &&
-            (!hiveQuery.logical.isInstanceOf[DescribeTableCommand]) &&
+            (!hiveQuery.logical.isInstanceOf[DescribeCommandBase]) &&
             preparedHive != catalyst) {
 
           val hivePrintOut = s"== HIVE - ${preparedHive.size} row(s) ==" +: preparedHive
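
As an illustrative aside to the DataStreamReaderWriterSuite changes above, the following is a minimal user-code sketch of the forced temp-checkpoint cleanup those tests exercise. It is not part of the patch, and it assumes the SQL config key backing SQLConf.FORCE_DELETE_TEMP_CHECKPOINT_LOCATION is "spark.sql.streaming.forceDeleteTempCheckpointLocation"; the object and app names are made up for the example.

// Hedged sketch, not part of the diff above.
// Assumption: "spark.sql.streaming.forceDeleteTempCheckpointLocation" is the key behind
// SQLConf.FORCE_DELETE_TEMP_CHECKPOINT_LOCATION.
import org.apache.spark.sql.SparkSession

object ForceDeleteTempCheckpointSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("force-delete-temp-checkpoint-sketch")
      .config("spark.sql.streaming.forceDeleteTempCheckpointLocation", "true")
      .getOrCreate()

    // No checkpointLocation option is set, so Spark creates a temporary checkpoint
    // directory for this query.
    val query = spark.readStream
      .format("rate")
      .load()
      .writeStream
      .format("console")
      .start()

    query.processAllAvailable()
    // With the flag enabled, the temporary checkpoint directory is cleaned up even when a
    // query ends with an error; an explicitly configured checkpoint location is always kept.
    query.stop()
    spark.stop()
  }
}
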