diff --git a/solr/contrib/ltr/README.txt b/solr/contrib/ltr/README.txt new file mode 100644 index 000000000000..5ab861f4a229 --- /dev/null +++ b/solr/contrib/ltr/README.txt @@ -0,0 +1,330 @@ +Apache Solr Learning to Rank +======== + +This is the main [learning to rank integrated into solr](http://www.slideshare.net/lucidworks/learning-to-rank-in-solr-presented-by-michael-nilsson-diego-ceccarelli-bloomberg-lp) +repository. +[Read up on learning to rank](https://en.wikipedia.org/wiki/Learning_to_rank) + +Apache Solr Learning to Rank (LTR) provides a way for you to extract features +directly inside Solr for use in training a machine learned model. You can then +deploy that model to Solr and use it to rerank your top X search results. + + +# Changes to solrconfig.xml +```xml + + ... + + + + + + + + + + + + + explicit + json + true + id + + + + ltrComponent + + + + + ... + + + + + + + +``` + + +# Build the plugin +In the solr/contrib/ltr directory run +`ant dist` + +# Install the plugin +In your solr installation, navigate to your collection's lib directory. +In the solr install example, it would be solr/collection1/lib. +If lib doesn't exist you will have to make it, and then copy the plugin's jar there. + +`cp lucene-solr/solr/dist/solr-ltr-X.Y.Z-SNAPSHOT.jar mySolrInstallPath/solr/myCollection/lib` + +Restart your collection using the admin page and you are good to go. +You can find more detailed instructions [here](https://wiki.apache.org/solr/SolrPlugins). + + +# Defining Features +In the learning to rank plugin, you can define features in a feature space +using standard Solr queries. 
As an example: + +###### features.json +```json +[ +{ "name": "isBook", + "type": "org.apache.solr.ltr.feature.impl.SolrFeature", + "params":{ "fq": ["{!terms f=category}book"] } +}, +{ + "name": "documentRecency", + "type": "org.apache.solr.ltr.feature.impl.SolrFeature", + "params": { + "q": "{!func}recip( ms(NOW,publish_date), 3.16e-11, 1, 1)" + } +}, +{ + "name":"originalScore", + "type":"org.apache.solr.ltr.feature.impl.OriginalScoreFeature", + "params":{} +}, +{ + "name" : "userTextTitleMatch", + "type" : "org.apache.solr.ltr.feature.impl.SolrFeature", + "params" : { "q" : "{!field f=title}${user_text}" } +} +] +``` + +This defines four features. Anything that is a valid Solr query can be used to define +a feature. + +### Filter Query Features +The first feature (isBook) fires if the term 'book' matches the category field +of the document being examined. Since q was not specified in this feature, +either the score 1 (in case of a match) or the score 0 (in case of no match) +will be returned. + +### Query Features +In the second feature (documentRecency) q was specified using a function query. +In this case the score for the feature on a given document is whatever the query +returns (1 for docs dated now, 1/2 for docs dated 1 year ago, 1/3 for docs dated +2 years ago, etc.). If both fq and q are used, documents that don't match +the fq will receive a score of 0 for the documentRecency feature; all other +documents will receive the score specified by the query for this feature. + +### Original Score Feature +The third feature (originalScore) has no parameters, and uses the +OriginalScoreFeature class instead of the SolrFeature class. Its purpose is +simply to return the score for the original search request against the current +matching document. + +### External Features +Users can specify external information that can be passed in as +part of the query to the ltr ranking framework. 
In this case, the +fourth feature (userTextTitleMatch) looks for an external field +called 'user_text' passed in through the request, and fires if there is +a term match for the document field 'title' from the value of the external +field 'user_text'. See the "Run a Rerank Query" section for how +to pass in external information. + +### Custom Features +Custom features can be created by extending +org.apache.solr.ltr.ranking.Feature; however, this is generally not recommended. +The majority of features should be possible to create using the methods described +above. + +# Defining Models +Currently the Learning to Rank plugin supports 2 main types of +ranking models: [Ranking SVM](http://www.cs.cornell.edu/people/tj/publications/joachims_02c.pdf) +and [LambdaMART](http://research.microsoft.com/pubs/132652/MSR-TR-2010-82.pdf). + +### Ranking SVM +Currently only a linear Ranking SVM is supported. Use LambdaMART for +a non-linear model. To introduce a bias, define a constant feature set +to the bias value and give that feature a weight of 1.0. + +###### model.json +```json +{ + "type":"org.apache.solr.ltr.ranking.RankSVMModel", + "name":"myModelName", + "features":[ + { "name": "userTextTitleMatch"}, + { "name": "originalScore"}, + { "name": "isBook"} + ], + "params":{ + "weights": { + "userTextTitleMatch": 1.0, + "originalScore": 0.5, + "isBook": 0.1 + } + + } +} +``` + +This is an example of a toy Ranking SVM model. Type specifies the class +used to interpret the model (RankSVMModel in the case of Ranking SVM). +Name is the model identifier you will use when making requests to the ltr +framework. Features specifies the feature space that you want extracted +when using this model. All features that appear in the model params will +be used for scoring and must appear in the features list. 
You can add +extra features to the features list that will be computed but not used in the +model for scoring, which can be useful for logging. +Params are the Ranking SVM parameters. + +Good libraries for training SVMs are LIBLINEAR (https://www.csie.ntu.edu.tw/~cjlin/liblinear/) +and LIBSVM (https://www.csie.ntu.edu.tw/~cjlin/libsvm/). You will need to convert the +libSVM model format to the format specified above. + +### LambdaMART + +###### model2.json +```json +{ + "type":"org.apache.solr.ltr.ranking.LambdaMARTModel", + "name":"lambdamartmodel", + "features":[ + { "name": "userTextTitleMatch"}, + { "name": "originalScore"} + ], + "params":{ + "trees": [ + { + "weight" : 1, + "tree": { + "feature": "userTextTitleMatch", + "threshold": 0.5, + "left" : { + "value" : -100 + }, + "right": { + "feature" : "originalScore", + "threshold": 10.0, + "left" : { + "value" : 50 + }, + "right" : { + "value" : 75 + } + } + } + }, + { + "weight" : 2, + "tree": { + "value" : -10 + } + } + ] + } +} +``` +This is an example of a toy LambdaMART model. Type specifies the class used to +interpret the model (LambdaMARTModel in the case of LambdaMART). Name is the +model identifier you will use when making requests to the ltr framework. +Features specifies the feature space that you want extracted when using this +model. All features that appear in the model params will be used for scoring and +must appear in the features list. You can add extra features to the features +list that will be computed but not used in the model for scoring, which can +be useful for logging. Params are the LambdaMART specific parameters. In this +case we have 2 trees, one with 3 leaf nodes and one with 1 leaf node. + +A good library for training LambdaMART models is RankLib (http://sourceforge.net/p/lemur/wiki/RankLib/). +You will need to convert the RankLib model format to the format specified above. 
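As a rough illustration of how a tree ensemble like the toy model above could be scored, here is a minimal Java sketch. It is not the plugin's LambdaMARTModel class. The class name `ToyLambdaMart`, the `Node` helper, the rule that values less than or equal to the threshold go to the left child, the treatment of a missing feature as 0, and the final score as the weighted sum of the tree outputs are all assumptions made for this example; the README itself does not spell them out.

```java
import java.util.HashMap;
import java.util.Map;

/** Illustrative only: a hand-built copy of the toy model, not the plugin's implementation. */
public class ToyLambdaMart {

  /** A node is either a leaf (feature == null) or a split on a single feature. */
  static final class Node {
    final String feature;
    final float threshold;
    final float value;
    final Node left;
    final Node right;

    private Node(String feature, float threshold, float value, Node left, Node right) {
      this.feature = feature;
      this.threshold = threshold;
      this.value = value;
      this.left = left;
      this.right = right;
    }

    static Node leaf(float value) {
      return new Node(null, 0f, value, null, null);
    }

    static Node split(String feature, float threshold, Node left, Node right) {
      return new Node(feature, threshold, 0f, left, right);
    }

    float eval(Map<String, Float> features) {
      if (feature == null) {
        return value; // leaf: return its value
      }
      // ASSUMPTION: missing features count as 0; values <= threshold go left.
      float v = features.getOrDefault(feature, 0f);
      return (v <= threshold ? left : right).eval(features);
    }
  }

  /** ASSUMPTION: the model score is the sum of weight * tree output over all trees. */
  static float score(Map<String, Float> features) {
    // First tree of model2.json (weight 1).
    Node tree1 = Node.split("userTextTitleMatch", 0.5f,
        Node.leaf(-100f),
        Node.split("originalScore", 10.0f, Node.leaf(50f), Node.leaf(75f)));
    // Second tree of model2.json (weight 2): a single leaf.
    Node tree2 = Node.leaf(-10f);
    return 1f * tree1.eval(features) + 2f * tree2.eval(features);
  }

  public static void main(String[] args) {
    Map<String, Float> features = new HashMap<>();
    features.put("userTextTitleMatch", 1.0f);
    features.put("originalScore", 12.0f);
    // Under the assumptions above: 1 * 75 + 2 * (-10)
    System.out.println(score(features)); // prints 55.0
  }
}
```

Under the same assumptions, a document whose title does not match (userTextTitleMatch = 0) would fall into the left leaf of the first tree and score 1 * (-100) + 2 * (-10) = -120.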
+ +# Deploy Models and Features +To send features run + +`curl -XPUT 'http://localhost:8983/solr/collection1/schema/fstore' --data-binary @/path/features.json -H 'Content-type:application/json'` + +To send models run + +`curl -XPUT 'http://localhost:8983/solr/collection1/schema/mstore' --data-binary @/path/model.json -H 'Content-type:application/json'` + + +# View Models and Features +`curl -XGET 'http://localhost:8983/solr/collection1/schema/fstore'` +`curl -XGET 'http://localhost:8983/solr/collection1/schema/mstore'` + + +# Run a Rerank Query +Add the following to your original Solr query: +`rq={!ltr model=myModelName reRankDocs=25}` + +The model name is the name of the model you sent to Solr earlier. +reRankDocs is the number of documents you want reranked, which can be +larger than the number you display. + +### Pass in external information for external features +Add the following to your original Solr query: +`rq={!ltr reRankDocs=3 model=externalmodel efi.field1='text1' efi.field2='text2'}` + +Here "field1" specifies the name of the customized field to be used by one +or more of your features, and text1 is the information to be passed in. As an +example that matches the userTextTitleMatch feature shown earlier, one could do: + +`rq={!ltr reRankDocs=3 model=externalmodel efi.user_text='Casablanca' efi.user_intent='movie'}` + +# Extract features +To extract features you need to use the feature vector transformer and set the +fv parameter to true (this required parameter will be removed in the future). +For now you also need to use a dummy model with all the features you want to +extract inside the features parameter list of the model (this limitation will +also be changed in the future so you can extract features without a dummy model). + +`fv=true&fl=*,score,[features]&rq={!ltr model=dummyModel reRankDocs=25}` + +## Test the plugin with solr/example/techproducts in 6 steps + +Solr provides some simple example indices. 
In order to test the plugin with +the techproducts example, follow these steps: + +1. compile solr and the examples + + cd solr + ant dist + ant example + +2. run the example + + ./bin/solr -e techproducts + +3. stop it and install the plugin: + + ./bin/solr stop + # create the lib folder + mkdir example/techproducts/solr/techproducts/lib + # install the plugin in the lib folder + cp build/contrib/ltr/lucene-ltr-6.0.0-SNAPSHOT.jar example/techproducts/solr/techproducts/lib/ + # replace the original solrconfig with one importing all the ltr components + cp contrib/ltr/example/solrconfig.xml example/techproducts/solr/techproducts/conf/ + +4. run the example again + + ./bin/solr -e techproducts + +5. index some features and a model + + curl -XPUT 'http://localhost:8983/solr/techproducts/schema/fstore' --data-binary "@./contrib/ltr/example/techproducts-features.json" -H 'Content-type:application/json' + curl -XPUT 'http://localhost:8983/solr/techproducts/schema/mstore' --data-binary "@./contrib/ltr/example/techproducts-model.json" -H 'Content-type:application/json' + +6. have fun! 
+ + # access to the default feature store + http://localhost:8983/solr/techproducts/schema/fstore/_DEFAULT_ + # access to the model store + http://localhost:8983/solr/techproducts/schema/mstore + # perform a query using the model, and retrieve the features + http://localhost:8983/solr/techproducts/query?indent=on&q=test&wt=json&rq={!ltr%20model=svm%20reRankDocs=25%20efi.query=%27test%27}&fl=*,[features],price,score,name&fv=true diff --git a/solr/contrib/ltr/build.xml b/solr/contrib/ltr/build.xml new file mode 100644 index 000000000000..d9598cc28234 --- /dev/null +++ b/solr/contrib/ltr/build.xml @@ -0,0 +1,27 @@ + + + + + + + + Learning to Rank Package + + + + diff --git a/solr/contrib/ltr/example/solrconfig.xml b/solr/contrib/ltr/example/solrconfig.xml new file mode 100644 index 000000000000..cfe060d703cd --- /dev/null +++ b/solr/contrib/ltr/example/solrconfig.xml @@ -0,0 +1,1743 @@ + + + + + + + + + 6.0.0 + + + + + + + + + + + + + + + + + + + + + + + ${solr.data.dir:} + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + ${solr.lock.type:native} + + + + + + + + + + + + + true + + + + + + + + + + + + + + + + ${solr.ulog.dir:} + ${solr.ulog.numVersionBuckets:65536} + + + + + ${solr.autoCommit.maxTime:15000} + false + + + + + + ${solr.autoSoftCommit.maxTime:-1} + + + + + + + + + + + + + + + + 1024 + + + + -1 + + + + + + + + + + + + + + + + + + + + + + + + + + + true + + + + + + 20 + + + 200 + + + + + + + + + + + + static firstSearcher warming in solrconfig.xml + + + + + + false + + + 2 + + + + + + + + + + + + + + + + + + + + + + + explicit + 10 + + false + + + + + + + + + + + + + + + + + + + + + + + explicit + json + true + text + + + + ltrComponent + + + + + + + + explicit + + + velocity + browse + layout + Solritas + + + edismax + + text^0.5 features^1.0 name^1.2 sku^1.5 id^10.0 manu^1.1 cat^1.4 + title^10.0 description^5.0 keywords^5.0 author^2.0 resourcename^1.0 + + 100% + *:* + 10 + *,score + + + text^0.5 features^1.0 name^1.2 sku^1.5 id^10.0 manu^1.1 
cat^1.4 + title^10.0 description^5.0 keywords^5.0 author^2.0 resourcename^1.0 + + text,features,name,sku,id,manu,cat,title,description,keywords,author,resourcename + 3 + + + on + true + cat + manu_exact + content_type + author_s + ipod + GB + 1 + cat,inStock + after + price + 0 + 600 + 50 + popularity + 0 + 10 + 3 + manufacturedate_dt + NOW/YEAR-10YEARS + NOW + +1YEAR + before + after + + + on + content features title name + true + html + <b> + </b> + 0 + title + 0 + name + 3 + 200 + content + 750 + + + on + false + 5 + 2 + 5 + true + true + 5 + 3 + + + + + spellcheck + + + + + + + text + + + + + + + _src_ + + true + + + + + + + + + + true + ignored_ + + + true + links + ignored_ + + + + + + + + + + + + + + + explicit + true + + + + + + + + + text_general + + + + + + default + text + solr.DirectSolrSpellChecker + + internal + + 0.5 + + 2 + + 1 + + 5 + + 4 + + 0.01 + + + + + + wordbreak + solr.WordBreakSolrSpellChecker + name + true + true + 10 + + + + + + + + + + + + + + + + + default + wordbreak + on + true + 10 + 5 + 5 + true + true + 10 + 5 + + + spellcheck + + + + + + + mySuggester + FuzzyLookupFactory + DocumentDictionaryFactory + cat + price + string + false + + + + + + true + 10 + + + suggest + + + + + + + + + + + true + + + tvComponent + + + + + + + + + lingo3g + true + com.carrotsearch.lingo3g.Lingo3GClusteringAlgorithm + clustering/carrot2 + + + + lingo + org.carrot2.clustering.lingo.LingoClusteringAlgorithm + clustering/carrot2 + + + + stc + org.carrot2.clustering.stc.STCClusteringAlgorithm + clustering/carrot2 + + + + kmeans + org.carrot2.clustering.kmeans.BisectingKMeansClusteringAlgorithm + clustering/carrot2 + + + + + + + true + true + + name + + id + + features + + true + + + + false + + + edismax + + text^0.5 features^1.0 name^1.2 sku^1.5 id^10.0 manu^1.1 cat^1.4 + + *:* + 100 + *,score + + + clustering + + + + + + + + + + true + false + + + terms + + + + + + + + string + elevate.xml + + + + + + explicit + + + elevator + + + + + + + + + + + 100 + + 
+ + + + + + 70 + + 0.5 + + [-\w ,/\n\"']{20,200} + + + + + + + ]]> + ]]> + + + + + + + + + + + + + + + + + + + + + + + + ,, + ,, + ,, + ,, + ,]]> + ]]> + + + + + + 10 + .,!? + + + + + + + WORD + + + en + US + + + + + + + + + + + + + + + + + + + + + + text/plain; charset=UTF-8 + + + + + ${velocity.template.base.dir:} + + + + + + 5 + + + + + + + + + + + + + + + + + + *:* + + + diff --git a/solr/contrib/ltr/example/techproducts-features.json b/solr/contrib/ltr/example/techproducts-features.json new file mode 100644 index 000000000000..142cb75d2a51 --- /dev/null +++ b/solr/contrib/ltr/example/techproducts-features.json @@ -0,0 +1,26 @@ +[ +{ + "name": "isInStock", + "type": "org.apache.solr.ltr.feature.impl.FieldValueFeature", + "params": { + "field": "inStock" + } +}, +{ + "name": "price", + "type": "org.apache.solr.ltr.feature.impl.FieldValueFeature", + "params": { + "field": "price" + } +}, +{ + "name":"originalScore", + "type":"org.apache.solr.ltr.feature.impl.OriginalScoreFeature", + "params":{} +}, +{ + "name" : "productNameMatchQuery", + "type" : "org.apache.solr.ltr.feature.impl.SolrFeature", + "params" : { "q" : "{!field f=name}${query}" } +} +] diff --git a/solr/contrib/ltr/example/techproducts-model.json b/solr/contrib/ltr/example/techproducts-model.json new file mode 100644 index 000000000000..38b0342197c0 --- /dev/null +++ b/solr/contrib/ltr/example/techproducts-model.json @@ -0,0 +1,18 @@ +{ + "type":"org.apache.solr.ltr.ranking.RankSVMModel", + "name":"svm", + "features":[ + {"name":"isInStock"}, + {"name":"price"}, + {"name":"originalScore"}, + {"name":"productNameMatchQuery"} + ], + "params":{ + "weights":{ + "isInStock":15.0, + "price":1.0, + "originalScore":5.0, + "productNameMatchQuery":1.0 + } + } +} diff --git a/solr/contrib/ltr/ivy.xml b/solr/contrib/ltr/ivy.xml new file mode 100644 index 000000000000..68e9797bb09a --- /dev/null +++ b/solr/contrib/ltr/ivy.xml @@ -0,0 +1,32 @@ + + + + + + + + + + + + + + + diff --git 
a/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/FeatureStore.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/FeatureStore.java new file mode 100644 index 000000000000..20779c541924 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/FeatureStore.java @@ -0,0 +1,80 @@ +package org.apache.solr.ltr.feature; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.util.ArrayList; +import java.util.Collection; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import org.apache.solr.ltr.ranking.Feature; +import org.apache.solr.ltr.util.FeatureException; + +public class FeatureStore { + LinkedHashMap store = new LinkedHashMap<>(); + String storeName; + + public FeatureStore(String storeName) { + this.storeName = storeName; + } + + public Feature get(String name) throws FeatureException { + if (!store.containsKey(name)) { + throw new FeatureException("missing feature " + name + + ". Store name was: '" + storeName + + "'. 
Possibly this feature exists in another context."); + } + return store.get(name); + } + + public int size() { + return store.size(); + } + + public boolean containsFeature(String name) { + return store.containsKey(name); + } + + public List featuresAsManagedResources() { + List features = new ArrayList(); + for (Feature f : store.values()) { + Map o = new LinkedHashMap<>(); + o.put("name", f.getName()); + o.put("type", f.getType()); + o.put("store", storeName); + o.put("params", f.getParams()); + features.add(o); + } + return features; + } + + public void add(Feature feature) { + store.put(feature.getName(), feature); + } + + public Collection getFeatures() { + return store.values(); + } + + public void clear() { + store.clear(); + + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/ModelMetadata.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/ModelMetadata.java new file mode 100644 index 000000000000..c7534e2d6f04 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/ModelMetadata.java @@ -0,0 +1,149 @@ +package org.apache.solr.ltr.feature; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import java.util.Collection; +import java.util.List; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.Explanation; +import org.apache.solr.ltr.ranking.Feature; +import org.apache.solr.ltr.util.NamedParams; + +/** + * Contains all the data needed for loading a model. + */ +// FIXME: Rename to something like RankingAlgorithm or ScoringAlgorithm +public abstract class ModelMetadata { + + private String name; + private String type; + private String featureStoreName; + private List features; + private Collection allFeatures; + private NamedParams params; + + public ModelMetadata(String name, String type, List features, + String featureStoreName, Collection allFeatures, + NamedParams params) { + this.name = name; + this.type = type; + this.features = features; + this.featureStoreName = featureStoreName; + this.allFeatures = allFeatures; + this.params = params; + } + + /** + * @return the name + */ + public String getName() { + return name; + } + + /** + * @return the type + */ + public String getType() { + return type; + } + + /** + * @return the features + */ + public List getFeatures() { + return features; + } + + public NamedParams getParams() { + return params; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((features == null) ? 0 : features.hashCode()); + result = prime * result + ((name == null) ? 0 : name.hashCode()); + result = prime * result + ((params == null) ? 0 : params.hashCode()); + result = prime * result + ((type == null) ? 
0 : type.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (obj == null) return false; + if (getClass() != obj.getClass()) return false; + ModelMetadata other = (ModelMetadata) obj; + if (features == null) { + if (other.features != null) return false; + } else if (!features.equals(other.features)) return false; + if (name == null) { + if (other.name != null) return false; + } else if (!name.equals(other.name)) return false; + if (params == null) { + if (other.params != null) return false; + } else if (!params.equals(other.params)) return false; + if (type == null) { + if (other.type != null) return false; + } else if (!type.equals(other.type)) return false; + return true; + } + + public boolean hasParams() { + return !(params == null || params.isEmpty()); + } + + public Collection getAllFeatures() { + return allFeatures; + } + + public String getFeatureStoreName() { + return featureStoreName; + } + + /** + * Given a list of normalized values for all features a scoring algorithm + * cares about, calculate and return a score. + * + * @param modelFeatureValuesNormalized + * List of normalized feature values. Each feature is identified by + * its id, which is the index in the array + * @return The final score for a document + */ + public abstract float score(float[] modelFeatureValuesNormalized); + + /** + * Similar to the score() function, except it returns an explanation of how + * the features were used to calculate the score. 
+ * + * @param context + * Context the document is in + * @param doc + * Document to explain + * @param finalScore + * Original score + * @param featureExplanations + * Explanations for each feature calculation + * @return Explanation for the scoring of a document + */ + public abstract Explanation explain(LeafReaderContext context, int doc, + float finalScore, List featureExplanations); + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/ModelStore.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/ModelStore.java new file mode 100644 index 000000000000..4743554cc21b --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/ModelStore.java @@ -0,0 +1,152 @@ +package org.apache.solr.ltr.feature; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.apache.solr.ltr.feature.norm.Normalizer; +import org.apache.solr.ltr.ranking.Feature; +import org.apache.solr.ltr.util.ModelException; +import org.apache.solr.ltr.util.NameValidator; + +/** + * Contains the model and features declared. 
+ */ +public class ModelStore { + + private Map availableModels; + + public ModelStore() { + availableModels = new HashMap<>(); + } + + public synchronized ModelMetadata getModel(String name) throws ModelException { + ModelMetadata model = availableModels.get(name); + if (model == null) { + throw new ModelException("cannot find model " + name); + } + return model; + + } + + public boolean containsModel(String modelName) { + return availableModels.containsKey(modelName); + } + + /** + * Returns the available models as a list of Map objects. After an update the + * managed resource needs to return the resources in this format in order to + * store them as JSON somewhere (zookeeper, disk...) + * + * TODO investigate if it is possible to replace the managed resources' json + * serializer/deserializer. + * + * @return the available models as a list of Map objects + */ + public List modelAsManagedResources() { + List list = new ArrayList<>(); + for (ModelMetadata modelmeta : availableModels.values()) { + Map modelMap = new HashMap<>(); + modelMap.put("name", modelmeta.getName()); + modelMap.put("type", modelmeta.getType()); + modelMap.put("store", modelmeta.getFeatureStoreName()); + List> features = new ArrayList<>(); + for (Feature meta : modelmeta.getFeatures()) { + Map map = new HashMap(); + map.put("name", meta.getName()); + + Normalizer n = meta.getNorm(); + + if (n != null) { + Map normalizer = new HashMap<>(); + normalizer.put("type", n.getType()); + normalizer.put("params", n.getParams()); + map.put("norm", normalizer); + } + features.add(map); + + } + modelMap.put("features", features); + modelMap.put("params", modelmeta.getParams()); + + list.add(modelMap); + } + return list; + } + + public void clear() { + availableModels.clear(); + + } + + @Override + public String toString() { + return "ModelStore [availableModels=" + availableModels.keySet() + "]"; + } + + public void delete(String childId) { + availableModels.remove(childId); + + } + + public synchronized 
void addModel(ModelMetadata modeldata) + throws ModelException { + String name = modeldata.getName(); + + if (modeldata.getFeatures().isEmpty()) { + throw new ModelException("no features declared for model " + + modeldata.getName()); + } + if (!NameValidator.check(name)) { + throw new ModelException("invalid model name " + name); + } + + if (containsModel(name)) { + throw new ModelException("model '" + name + + "' already exists. Please use a different name"); + } + + String type = modeldata.getType(); + try { + Class.forName(type); + } catch (ClassNotFoundException e) { + throw new ModelException("cannot find class " + type + + " implementing model " + name, e); + } + + // checks for duplicates in the feature + Set names = new HashSet<>(); + for (Feature feature : modeldata.getFeatures()) { + String fname = feature.getName(); + if (names.contains(fname)) { + throw new ModelException("duplicated feature " + fname + " in model " + + name); + } + + names.add(fname); + } + + availableModels.put(modeldata.getName(), modeldata); + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/impl/FieldLengthFeature.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/impl/FieldLengthFeature.java new file mode 100644 index 000000000000..1337195be88a --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/impl/FieldLengthFeature.java @@ -0,0 +1,147 @@ +package org.apache.solr.ltr.feature.impl; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.io.IOException; + +import org.apache.lucene.index.IndexableField; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.NumericDocValues; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.util.SmallFloat; +import org.apache.solr.ltr.feature.norm.Normalizer; +import org.apache.solr.ltr.ranking.Feature; +import org.apache.solr.ltr.ranking.FeatureScorer; +import org.apache.solr.ltr.ranking.FeatureWeight; +import org.apache.solr.ltr.util.FeatureException; +import org.apache.solr.ltr.util.NamedParams; + +public class FieldLengthFeature extends Feature { + String field; + + /** Cache of decoded bytes. */ + + private static final float[] NORM_TABLE = new float[256]; + + static { + for (int i = 0; i < 256; i++) { + NORM_TABLE[i] = SmallFloat.byte315ToFloat((byte) i); + + } + + } + + /** + * Decodes the norm value, assuming it is a single byte. 
+ * + */ + + private final float decodeNorm(long norm) { + return NORM_TABLE[(int) (norm & 0xFF)]; // & 0xFF maps negative bytes to + // positive above 127 + } + + public FieldLengthFeature() { + + } + + public void init(String name, NamedParams params, int id) + throws FeatureException { + super.init(name, params, id); + if (!params.containsKey("field")) { + throw new FeatureException("missing param field"); + } + } + + @Override + public FeatureWeight createWeight(IndexSearcher searcher, boolean needsScores) + throws IOException { + this.field = (String) params.get("field"); + return new FieldLengthFeatureWeight(searcher, name, params, norm, id); + } + + @Override + public String toString(String f) { + return "FieldLengthFeature [field:" + field + "]"; + + } + + public class FieldLengthFeatureWeight extends FeatureWeight { + + public FieldLengthFeatureWeight(IndexSearcher searcher, String name, + NamedParams params, Normalizer norm, int id) { + super(FieldLengthFeature.this, searcher, name, params, norm, id); + } + + @Override + public FeatureScorer scorer(LeafReaderContext context) throws IOException { + return new FieldLengthFeatureScorer(this, context); + + } + + public class FieldLengthFeatureScorer extends FeatureScorer { + + LeafReaderContext context = null; + NumericDocValues norms = null; + DocIdSetIterator itr; + + public FieldLengthFeatureScorer(FeatureWeight weight, + LeafReaderContext context) throws IOException { + super(weight); + this.context = context; + this.itr = new MatchAllIterator(); + norms = context.reader().getNormValues(field); + + // In the constructor, docId is -1, so using 0 as default lookup + IndexableField idxF = searcher.doc(0).getField(field); + if (idxF.fieldType().omitNorms()) throw new IOException( + "FieldLengthFeatures can't be used if omitNorms is enabled (field=" + + field + ")"); + + } + + @Override + public float score() throws IOException { + + long l = norms.get(itr.docID()); + float norm = decodeNorm(l); + float 
numTerms = (float) Math.pow(1f / norm, 2); + + return numTerms; + } + + @Override + public String toString() { + return "FieldLengthFeature [name=" + name + " field=" + field + "]"; + } + + @Override + public int docID() { + return itr.docID(); + } + + @Override + public DocIdSetIterator iterator() { + return itr; + } + + } + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/impl/FieldValueFeature.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/impl/FieldValueFeature.java new file mode 100644 index 000000000000..47e948010b9d --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/impl/FieldValueFeature.java @@ -0,0 +1,143 @@ +package org.apache.solr.ltr.feature.impl; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import java.io.IOException; +import java.util.Set; + +import org.apache.lucene.document.Document; +import org.apache.lucene.index.IndexableField; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.IndexSearcher; +import org.apache.solr.ltr.feature.norm.Normalizer; +import org.apache.solr.ltr.ranking.Feature; +import org.apache.solr.ltr.ranking.FeatureScorer; +import org.apache.solr.ltr.ranking.FeatureWeight; +import org.apache.solr.ltr.util.FeatureException; +import org.apache.solr.ltr.util.NamedParams; + +import com.google.common.collect.Sets; + +public class FieldValueFeature extends Feature { + String fieldName; + Set fields = Sets.newHashSet(); + + public FieldValueFeature() { + + } + + public void init(String name, NamedParams params, int id) + throws FeatureException { + super.init(name, params, id); + if (!params.containsKey("field")) { + throw new FeatureException("missing param field"); + } + } + + @Override + public FeatureWeight createWeight(IndexSearcher searcher, boolean needsScores) + throws IOException { + this.fieldName = (String) params.get("field"); + fields.add(this.fieldName); + return new FieldValueFeatureWeight(searcher, name, params, norm, id); + } + + @Override + public String toString(String f) { + return "FieldValueFeature [field:" + fieldName + "]"; + + } + + public class FieldValueFeatureWeight extends FeatureWeight { + + public FieldValueFeatureWeight(IndexSearcher searcher, String name, + NamedParams params, Normalizer norm, int id) { + super(FieldValueFeature.this, searcher, name, params, norm, id); + } + + @Override + public FeatureScorer scorer(LeafReaderContext context) throws IOException { + return new FieldValueFeatureScorer(this, context); + } + + public class FieldValueFeatureScorer extends FeatureScorer { + + LeafReaderContext context = null; + DocIdSetIterator itr; + + public FieldValueFeatureScorer(FeatureWeight weight, + 
+ LeafReaderContext context) { + super(weight); + this.context = context; + this.itr = new MatchAllIterator(); + } + + @Override + public float score() throws IOException { + + try { + Document document = context.reader().document(itr.docID(), fields); + IndexableField field = document.getField(fieldName); + if (field == null) { + // logger.debug("no field {}", f); + // TODO define default value + return 0; + } + Number number = field.numericValue(); + if (number != null) { + return number.floatValue(); + } else { + String string = field.stringValue(); + // boolean values in the index are encoded with the + // chars T/F + if (string.equals("T")) { + return 1; + } + if (string.equals("F")) { + return 0; + } + } + } catch (IOException e) { + // TODO discuss feature failures: + // do we want to return a default value? + // do we want to fail? + } + // TODO define default value + return 0; + } + + @Override + public String toString() { + return "FieldValueFeature [name=" + name + " fields=" + fields + "]"; + } + + @Override + public int docID() { + return itr.docID(); + } + + @Override + public DocIdSetIterator iterator() { + return itr; + } + + } + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/impl/MatchAllIterator.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/impl/MatchAllIterator.java new file mode 100644 index 000000000000..f9005abc7886 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/impl/MatchAllIterator.java @@ -0,0 +1,52 @@ +package org.apache.solr.ltr.feature.impl; + +import java.io.IOException; + +import org.apache.lucene.search.DocIdSetIterator; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +public class MatchAllIterator extends DocIdSetIterator { + protected int docID = -1; + + @Override + public int docID() { + return docID; + } + + @Override + public int nextDoc() throws IOException { + // only the rescorer calls this scorer, so nextDoc should never be called + return ++docID; // FIXME: Keep this or throw new + // UnsupportedOperationException()? + } + + @Override + public int advance(int target) throws IOException { + // For advanced features that use Solr scorers internally, you must override + // and pass this call on to them + docID = target; + return docID; + } + + @Override + public long cost() { + return 0; // FIXME: Do something here + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/impl/OriginalScoreFeature.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/impl/OriginalScoreFeature.java new file mode 100644 index 000000000000..5734c7a6bd6c --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/impl/OriginalScoreFeature.java @@ -0,0 +1,114 @@ +package org.apache.solr.ltr.feature.impl; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.io.IOException; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; +import org.apache.solr.ltr.feature.norm.Normalizer; +import org.apache.solr.ltr.ranking.Feature; +import org.apache.solr.ltr.ranking.FeatureScorer; +import org.apache.solr.ltr.ranking.FeatureWeight; +import org.apache.solr.ltr.ranking.LTRRescorer; +import org.apache.solr.ltr.util.NamedParams; + +public class OriginalScoreFeature extends Feature { + + @Override + public OriginalScoreWeight createWeight(IndexSearcher searcher, + boolean needsScores) throws IOException { + return new OriginalScoreWeight(searcher, name, params, norm, id); + + } + + public class OriginalScoreWeight extends FeatureWeight { + + Weight w = null; + + public OriginalScoreWeight(IndexSearcher searcher, String name, + NamedParams params, Normalizer norm, int id) { + super(OriginalScoreFeature.this, searcher, name, params, norm, id); + + } + + public void process() throws IOException { + // I can't set w before in the constructor because I would need to have it + // in the query for doing that. But the query/feature is shared among + // different threads so I can't set the original query there. 
+ w = searcher.createNormalizedWeight(this.originalQuery, true); + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) + throws IOException { + // Explanation e = w.explain(context, doc); + Scorer s = w.scorer(context); + s.iterator().advance(doc); + float score = s.score(); + return Explanation.match(score, "original score query: " + originalQuery); + } + + @Override + public FeatureScorer scorer(LeafReaderContext context) throws IOException { + + Scorer originalScorer = w.scorer(context); + return new OriginalScoreScorer(this, originalScorer); + } + + public class OriginalScoreScorer extends FeatureScorer { + Scorer originalScorer; + + public OriginalScoreScorer(FeatureWeight weight, Scorer originalScorer) { + super(weight); + this.originalScorer = originalScorer; + } + + @Override + public float score() throws IOException { + // This is done to improve the speed of feature extraction. Since this + // was already scored in step 1 + // we shouldn't need to calc original score again. + return this.hasDocParam(LTRRescorer.ORIGINAL_DOC_NAME) ? (Float) this + .getDocParam(LTRRescorer.ORIGINAL_DOC_NAME) : originalScorer + .score(); + } + + @Override + public String toString() { + return "OriginalScoreFeature [query:" + originalQuery.toString() + "]"; + } + + @Override + public int docID() { + return originalScorer.docID(); + } + + @Override + public DocIdSetIterator iterator() { + return originalScorer.iterator(); + } + } + + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/impl/SolrFeature.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/impl/SolrFeature.java new file mode 100644 index 000000000000..ed3af87ccc4c --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/impl/SolrFeature.java @@ -0,0 +1,305 @@ +package org.apache.solr.ltr.feature.impl; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSet; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; +import org.apache.lucene.util.Bits; +import org.apache.solr.common.params.CommonParams; +import org.apache.solr.common.util.NamedList; +import org.apache.solr.core.SolrCore; +import org.apache.solr.ltr.feature.norm.Normalizer; +import org.apache.solr.ltr.ranking.Feature; +import org.apache.solr.ltr.ranking.FeatureScorer; +import org.apache.solr.ltr.ranking.FeatureWeight; +import org.apache.solr.ltr.util.FeatureException; +import org.apache.solr.ltr.util.NamedParams; +import org.apache.solr.request.LocalSolrQueryRequest; +import org.apache.solr.request.SolrQueryRequest; +import org.apache.solr.search.QParser; +import org.apache.solr.search.SolrIndexSearcher; +import org.apache.solr.search.SolrIndexSearcher.ProcessedFilter; + +public class SolrFeature extends Feature { + + @Override + public FeatureWeight createWeight(IndexSearcher searcher, boolean needsScores) + throws IOException { 
+ return new SolrFeatureWeight(searcher, name, params, norm, id); + } + + public class SolrFeatureWeight extends FeatureWeight { + Weight solrQueryWeight; + Query query; + List queryAndFilters; + + public SolrFeatureWeight(IndexSearcher searcher, String name, + NamedParams params, Normalizer norm, int id) throws IOException { + super(SolrFeature.this, searcher, name, params, norm, id); + } + + @Override + public void process() throws FeatureException { + try { + String df = (String) getParams().get(CommonParams.DF); + String defaultParser = (String) getParams().get("defaultParser"); + String solrQuery = (String) getParams().get(CommonParams.Q); + List fqs = (List) getParams().get(CommonParams.FQ); + + if ((solrQuery == null || solrQuery.isEmpty()) + && (fqs == null || fqs.isEmpty())) { + throw new IOException("ERROR: FQ or Q have not been provided"); + } + + if (solrQuery == null || solrQuery.isEmpty()) { + solrQuery = "*:*"; + } + solrQuery = macroExpander.expand(solrQuery); + + SolrQueryRequest req = makeRequest(request.getCore(), solrQuery, fqs, + df); + if (req == null) { + throw new IOException("ERROR: No parameters provided"); + } + + // Build the filter queries + this.queryAndFilters = new ArrayList(); // If there are no fqs we + // just want an empty + // list + if (fqs != null) { + for (String fq : fqs) { + if (fq != null && fq.trim().length() != 0) { + fq = macroExpander.expand(fq); + QParser fqp = QParser.getParser(fq, null, req); + Query filterQuery = fqp.getQuery(); + if (filterQuery != null) { + queryAndFilters.add(filterQuery); + } + } + } + } + + QParser parser = QParser.getParser(solrQuery, + defaultParser == null ? "lucene" : defaultParser, req); + query = parser.parse(); + + // Query can be null if there was no input to parse, for instance if you + // make a phrase query with "to be", and the analyzer removes all the + // words + // leaving nothing for the phrase query to parse. 
+ if (query != null) { + queryAndFilters.add(query); + solrQueryWeight = searcher.createNormalizedWeight(query, true); + } + + } catch (Exception e) { + throw new FeatureException("Exception for " + this.toString() + " " + + e.getMessage(), e); + } + } + + private LocalSolrQueryRequest makeRequest(SolrCore core, String solrQuery, + List fqs, String df) { + // Map.Entry [] entries = new NamedListEntry[q.length / + // 2]; + NamedList returnList = new NamedList(); + if (solrQuery != null && !solrQuery.isEmpty()) { + returnList.add(CommonParams.Q, solrQuery); + } + if (fqs != null) { + for (String fq : fqs) { + returnList.add(CommonParams.FQ, fq); + // entries[i/2] = new NamedListEntry<>(q[i], q[i+1]); + } + } + if (df != null && !df.isEmpty()) { + returnList.add(CommonParams.DF, df); + } + if (returnList.size() > 0) return new LocalSolrQueryRequest(core, + returnList); + else return null; + } + + @Override + public FeatureScorer scorer(LeafReaderContext context) throws IOException { + Scorer solrScorer = null; + if (solrQueryWeight != null) { + solrScorer = solrQueryWeight.scorer(context); + } + + DocIdSetIterator idItr = getDocIdSetIteratorFromQueries(queryAndFilters, + context); + if (idItr != null) { + return solrScorer == null ? new SolrFeatureFilterOnlyScorer(this, idItr) + : new SolrFeatureScorer(this, solrScorer, idItr); + } else { + return null; + } + } + + /** + * Given a list of Solr filters/queries, return a doc iterator that + * traverses over the documents that matched all the criteria of the + * queries. + * + * @param queries + * Filtering criteria to match documents against + * @param context + * Index reader + * @return DocIdSetIterator to traverse documents that matched all filter + * criteria + */ + public DocIdSetIterator getDocIdSetIteratorFromQueries(List queries, + LeafReaderContext context) throws IOException { + // FIXME: Only SolrIndexSearcher has getProcessedFilter(), but all weights + // are given an IndexSearcher instead. 
+ // Ideally there should be some guarantee that we have a SolrIndexSearcher + // so we don't have to cast. + ProcessedFilter pf = ((SolrIndexSearcher) searcher).getProcessedFilter( + null, queries); + final Bits liveDocs = context.reader().getLiveDocs(); + + DocIdSetIterator idIter = null; + if (pf.filter != null) { + DocIdSet idSet = pf.filter.getDocIdSet(context, liveDocs); + if (idSet != null) idIter = idSet.iterator(); + } + + return idIter; + } + + public class SolrFeatureScorer extends FeatureScorer { + Scorer solrScorer; + String q; + DocIdSetIterator itr; + + public SolrFeatureScorer(FeatureWeight weight, Scorer solrScorer, + DocIdSetIterator filterIterator) { + super(weight); + q = (String) getParams().get(CommonParams.Q); + this.solrScorer = solrScorer; + this.itr = new SolrFeatureScorerIterator(filterIterator, + solrScorer.iterator()); + } + + @Override + public float score() throws IOException { + return solrScorer.score(); + } + + @Override + public String toString() { + return "SolrFeature [function:" + q + "]"; + } + + @Override + public DocIdSetIterator iterator() { + return itr; + } + + @Override + public int docID() { + return itr.docID(); + } + + private class SolrFeatureScorerIterator extends DocIdSetIterator { + + DocIdSetIterator filterIterator; + DocIdSetIterator scorerFilter; + int docID; + + SolrFeatureScorerIterator(DocIdSetIterator filterIterator, + DocIdSetIterator scorerFilter) { + this.filterIterator = filterIterator; + this.scorerFilter = scorerFilter; + } + + @Override + public int docID() { + return filterIterator.docID(); + } + + @Override + public int nextDoc() throws IOException { + docID = filterIterator.nextDoc(); + scorerFilter.advance(docID); + return docID; + } + + @Override + public int advance(int target) throws IOException { + // We use iterator to catch the scorer up since + // that checks if the target id is in the query + all the filters + docID = filterIterator.advance(target); + scorerFilter.advance(docID); + return 
docID; + } + + @Override + public long cost() { + return 0; // FIXME: Make this work? + } + + } + } + + public class SolrFeatureFilterOnlyScorer extends FeatureScorer { + String fq; + DocIdSetIterator itr; + + public SolrFeatureFilterOnlyScorer(FeatureWeight weight, + DocIdSetIterator iterator) { + super(weight); + fq = (String) getParams().get(CommonParams.FQ); + this.itr = iterator; + } + + @Override + public float score() throws IOException { + return 1f; + } + + @Override + public String toString() { + return "SolrFeature [function:" + fq + "]"; + } + + @Override + public DocIdSetIterator iterator() { + return itr; + } + + @Override + public int docID() { + return itr.docID(); + } + + } + + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/impl/ValueFeature.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/impl/ValueFeature.java new file mode 100644 index 000000000000..ee9b07f484b7 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/impl/ValueFeature.java @@ -0,0 +1,141 @@ +package org.apache.solr.ltr.feature.impl; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import java.io.IOException; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.IndexSearcher; +import org.apache.solr.ltr.feature.norm.Normalizer; +import org.apache.solr.ltr.ranking.Feature; +import org.apache.solr.ltr.ranking.FeatureScorer; +import org.apache.solr.ltr.ranking.FeatureWeight; +import org.apache.solr.ltr.util.FeatureException; +import org.apache.solr.ltr.util.NamedParams; + +public class ValueFeature extends Feature { + + protected float configValue = -1f; + protected String configValueStr = null; + + public ValueFeature() {} + + @Override + public void init(String name, NamedParams params, int id) + throws FeatureException { + super.init(name, params, id); + Object paramValue = params.get("value"); + if (paramValue == null) { + throw new FeatureException("Missing the field 'value' in params for " + + this); + } + + if (paramValue instanceof String) { + this.configValueStr = (String) paramValue; + if (this.configValueStr.trim().isEmpty()) { + throw new FeatureException("Empty field 'value' in params for " + this); + } + } else { + try { + this.configValue = NamedParams.convertToFloat(paramValue); + } catch (NumberFormatException e) { + throw new FeatureException("Invalid type for 'value' in params for " + + this); + } + } + } + + @Override + public FeatureWeight createWeight(IndexSearcher searcher, boolean needsScores) + throws IOException { + return new ValueFeatureWeight(searcher, name, params, norm, id); + } + + public class ValueFeatureWeight extends FeatureWeight { + + protected float featureValue; + + public ValueFeatureWeight(IndexSearcher searcher, String name, + NamedParams params, Normalizer norm, int id) { + super(ValueFeature.this, searcher, name, params, norm, id); + } + + @Override + public void process() throws IOException { + // Value replace from external feature info if applicable. 
Each request + // can change the + // value if it is using ${myExternalValue} for the configValueStr, + // otherwise use the + // constant value provided in the config. + if (configValueStr != null) { + featureValue = Float.parseFloat(macroExpander.expand(configValueStr)); + } else { + featureValue = configValue; + } + } + + @Override + public FeatureScorer scorer(LeafReaderContext context) throws IOException { + return new ValueFeatureScorer(this, featureValue, "ValueFeature"); + } + + /** + * Default FeatureScorer class that returns the score passed in. Can be used + * as a simple ValueFeature, or to return a default scorer in case an + * underlying feature's scorer is null. + */ + public class ValueFeatureScorer extends FeatureScorer { + + float constScore; + String featureType; + DocIdSetIterator itr; + + public ValueFeatureScorer(FeatureWeight weight, float constScore, + String featureType) { + super(weight); + this.constScore = constScore; + this.featureType = featureType; + this.itr = new MatchAllIterator(); + } + + @Override + public float score() { + return constScore; + } + + @Override + public String toString() { + return featureType + " [name=" + name + " value=" + constScore + "]"; + } + + @Override + public int docID() { + return itr.docID(); + } + + @Override + public DocIdSetIterator iterator() { + return itr; + } + + } + + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/impl/package-info.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/impl/package-info.java new file mode 100644 index 000000000000..fda5a517060c --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/impl/package-info.java @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Implementations of the Features. + */ +package org.apache.solr.ltr.feature.impl; diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/norm/Normalizer.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/norm/Normalizer.java new file mode 100644 index 000000000000..36f788f0eb5b --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/norm/Normalizer.java @@ -0,0 +1,66 @@ +package org.apache.solr.ltr.feature.norm; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import org.apache.lucene.search.Explanation; +import org.apache.solr.ltr.feature.norm.impl.IdentityNormalizer; +import org.apache.solr.ltr.feature.norm.impl.StandardNormalizer; +import org.apache.solr.ltr.util.NamedParams; +import org.apache.solr.ltr.util.NormalizerException; + +/** + * A normalizer normalizes the value of a feature. Once the feature values + * have been computed, the normalizer is applied and the model receives + * the resulting values. + * + * @see IdentityNormalizer + * @see StandardNormalizer + * + */ +public abstract class Normalizer { + + protected String type = this.getClass().getCanonicalName(); + NamedParams params; + + public String getType() { + return type; + } + + public NamedParams getParams() { + return params; + } + + public void setType(String type) { + this.type = type; + } + + public void init(NamedParams params) throws NormalizerException { + this.params = params; + } + + public abstract float normalize(float value); + + public Explanation explain(Explanation explain) { + float normalized = normalize(explain.getValue()); + String explainDesc = "normalized using " + type; + if (params != null) explainDesc += " [params " + params + "]"; + + return Explanation.match(normalized, explainDesc, explain); + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/norm/impl/IdentityNormalizer.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/norm/impl/IdentityNormalizer.java new file mode 100644 index 000000000000..77a2a264c468 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/norm/impl/IdentityNormalizer.java @@ -0,0 +1,38 @@ +package org.apache.solr.ltr.feature.norm.impl; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.solr.ltr.feature.norm.Normalizer; +import org.apache.solr.ltr.util.NamedParams; + +public class IdentityNormalizer extends Normalizer { + + public static final IdentityNormalizer INSTANCE = new IdentityNormalizer(); + + public IdentityNormalizer() { + + } + + public void init(NamedParams params) {} + + @Override + public float normalize(float value) { + return value; + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/norm/impl/MinMaxNormalizer.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/norm/impl/MinMaxNormalizer.java new file mode 100644 index 000000000000..93e4c58d17b0 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/norm/impl/MinMaxNormalizer.java @@ -0,0 +1,58 @@ +package org.apache.solr.ltr.feature.norm.impl; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.solr.ltr.feature.norm.Normalizer; +import org.apache.solr.ltr.util.NamedParams; +import org.apache.solr.ltr.util.NormalizerException; + +public class MinMaxNormalizer extends Normalizer { + + private float min; + private float max; + private float delta; + + public void init(NamedParams params) throws NormalizerException { + super.init(params); + if (!params.containsKey("min")) throw new NormalizerException( + "missing required param [min] for normalizer MinMaxNormalizer"); + if (!params.containsKey("max")) throw new NormalizerException( + "missing required param [max] for normalizer MinMaxNormalizer"); + try { + min = (float) params.getFloat("min"); + + max = (float) params.getFloat("max"); + + } catch (Exception e) { + throw new NormalizerException( + "invalid param value for normalizer MinMaxNormalizer", e); + } + + delta = max - min; + if (delta <= 0) { + throw new NormalizerException( + "invalid param value for MinMaxNormalizer, min must be lower than max "); + } + } + + @Override + public float normalize(float value) { + return (value - min) / delta; + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/norm/impl/StandardNormalizer.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/norm/impl/StandardNormalizer.java new file mode 100644 index 000000000000..8c3fe4b4e5c8 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/norm/impl/StandardNormalizer.java @@ -0,0 +1,47 @@ +package org.apache.solr.ltr.feature.norm.impl; + +/* + * Licensed to the Apache Software Foundation 
(ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.solr.ltr.feature.norm.Normalizer; +import org.apache.solr.ltr.util.NamedParams; +import org.apache.solr.ltr.util.NormalizerException; + +public class StandardNormalizer extends Normalizer { + + private float avg; + private float std; + + public void init(NamedParams params) throws NormalizerException { + super.init(params); + if (!params.containsKey("avg")) { + throw new NormalizerException("missing param avg"); + } + if (!params.containsKey("std")) { + throw new NormalizerException("missing param std"); + } + avg = params.getFloat("avg", 0); + std = params.getFloat("std", 1); + if (std <= 0) throw new NormalizerException("std must be > 0"); + } + + @Override + public float normalize(float value) { + return (value - avg) / std; + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/norm/impl/package-info.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/norm/impl/package-info.java new file mode 100644 index 000000000000..a105e500aadd --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/norm/impl/package-info.java @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * This package contains several implementations of normalizers. + */ +package org.apache.solr.ltr.feature.norm.impl; diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/norm/package-info.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/norm/package-info.java new file mode 100644 index 000000000000..0cc911cb35af --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/norm/package-info.java @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/** + * A normalizer normalizes the value of a feature. Once the feature values + * have been computed, the normalizer is applied and the resulting values + * are passed to the model. + */ +package org.apache.solr.ltr.feature.norm; diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/package-info.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/package-info.java new file mode 100644 index 000000000000..1e33e2af70cb --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/package-info.java @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Contains implementations of Feature and Model stores. + */ +package org.apache.solr.ltr.feature; diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/log/FeatureLogger.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/log/FeatureLogger.java new file mode 100644 index 000000000000..e32cbab70d82 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/log/FeatureLogger.java @@ -0,0 +1,157 @@ +package org.apache.solr.ltr.log; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.lang.invoke.MethodHandles; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.apache.solr.ltr.ranking.ModelQuery; +import org.apache.solr.search.SolrCache; +import org.apache.solr.search.SolrIndexSearcher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * FeatureLogger can be registered in a model and provides a strategy for logging + * the feature values. + */ +public abstract class FeatureLogger<FV_TYPE> { + + private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + public static final String QUERY_FV_CACHE_NAME = "QUERY_DOC_FV"; + + /** + * Log will be called every time that the model generates the feature values + * for a document and a query. + * + * @param docid + * Solr document id whose features we are saving + * @param featureNames + * List of all the feature names we are logging + * @param featureValues + * Parallel list to featureNames that stores all the unnormalized + * feature values + * @param featuresUsed + * Bitset indicating which featureValues to log + * + * @return true if the logger successfully logged the features, false + * otherwise. 
+ */ + + public boolean log(int docid, ModelQuery modelQuery, + SolrIndexSearcher searcher, String[] featureNames, float[] featureValues, + boolean[] featuresUsed) { + FV_TYPE r = makeFeatureVector(featureNames, featureValues, featuresUsed); + if (r == null) return false; + // FIXME: Confirm this hashing works + return searcher.cacheInsert(QUERY_FV_CACHE_NAME, modelQuery.hashCode() + 31 + * docid, r) != null; + } + + /** + * returns a FeatureLogger that logs the features in output, using the format + * specified in the 'format' param: 'csv' logs the features as a single + * string in csv format; 'json' logs the features as a Map of featureName + * keys to featureValue values. If format is null or empty, the csv format + * is selected. Returns null for an unknown format. + * + * @return a feature logger for the format specified. + */ + public static FeatureLogger getFeatureLogger(String format) { + if (format == null || format.isEmpty()) return new CSVFeatureLogger(); + if (format.equals("csv")) return new CSVFeatureLogger(); + if (format.equals("json")) return new MapFeatureLogger(); + logger.warn("unknown feature logger {}", format); + return null; + + } + + public abstract FV_TYPE makeFeatureVector(String[] featureNames, + float[] featureValues, boolean[] featuresUsed); + + /** + * populate the document with its feature vector + * + * @param docid + * Solr document id + * @return String representation of the list of features calculated for docid + */ + public FV_TYPE getFeatureVector(int docid, ModelQuery reRankModel, + SolrIndexSearcher searcher) { + SolrCache fvCache = searcher.getCache(QUERY_FV_CACHE_NAME); + return fvCache == null ? 
null : (FV_TYPE) fvCache.get(reRankModel + .hashCode() + 31 * docid); + } + + public static class MapFeatureLogger extends FeatureLogger<Map<String,Float>> { + + @Override + public Map<String,Float> makeFeatureVector(String[] featureNames, + float[] featureValues, boolean[] featuresUsed) { + Map<String,Float> hashmap = Collections.emptyMap(); + if (featureNames.length > 0) { + hashmap = new HashMap<String,Float>(featureValues.length); + for (int i = 0; i < featuresUsed.length; i++) { + if (featuresUsed[i]) { + hashmap.put(featureNames[i], featureValues[i]); + } + } + } + return hashmap; + } + + } + + public static class CSVFeatureLogger extends FeatureLogger<String> { + StringBuilder sb = new StringBuilder(500); + char keyValueSep = ':'; + char featureSep = ';'; + + public CSVFeatureLogger setKeyValueSep(char keyValueSep) { + this.keyValueSep = keyValueSep; + return this; + } + + public CSVFeatureLogger setFeatureSep(char featureSep) { + this.featureSep = featureSep; + return this; + } + + @Override + public String makeFeatureVector(String[] featureNames, + float[] featureValues, boolean[] featuresUsed) { + for (int i = 0; i < featuresUsed.length; i++) { + if (featuresUsed[i]) { + sb.append(featureNames[i]).append(keyValueSep) + .append(featureValues[i]); + sb.append(featureSep); + } + } + + String features = (sb.length() > 0 ? sb.substring(0, sb.length() - 1) + : ""); + sb.setLength(0); + + return features; + } + + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/log/package-info.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/log/package-info.java new file mode 100644 index 000000000000..8a55cc361730 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/log/package-info.java @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * FeatureLogger can be registered in a model and provide a strategy for logging the feature values. + */ +package org.apache.solr.ltr.log; diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/Feature.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/Feature.java new file mode 100644 index 000000000000..e93f9740a003 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/Feature.java @@ -0,0 +1,141 @@ +package org.apache.solr.ltr.ranking; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import java.io.IOException; + +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.solr.ltr.feature.norm.Normalizer; +import org.apache.solr.ltr.feature.norm.impl.IdentityNormalizer; +import org.apache.solr.ltr.util.FeatureException; +import org.apache.solr.ltr.util.NamedParams; + +/** + * A 'recipe' for computing a feature + */ +public abstract class Feature extends Query implements Cloneable { + + protected String name; + protected String type = this.getClass().getCanonicalName(); + protected Normalizer norm = IdentityNormalizer.INSTANCE; + protected int id; + protected NamedParams params = NamedParams.EMPTY; + + public void init(String name, NamedParams params, int id) + throws FeatureException { + this.name = name; + this.params = params; + this.id = id; + } + + public Feature() { + + } + + /** Returns a clone of this feature query. */ + @Override + public Query clone() { + + try { + return (Query) super.clone(); + } catch (CloneNotSupportedException e) { + // FIXME throw the exception, wrap into another exception? + e.printStackTrace(); + } + return null; + } + + @Override + public String toString(String field) { + return "Feature [name=" + name + ", type=" + type + ", id=" + id + + ", params=" + params + "]"; + } + + public abstract FeatureWeight createWeight(IndexSearcher searcher, + boolean needsScores) throws IOException; + + @Override + public int hashCode() { + final int prime = 31; + int result = super.hashCode(); + result = prime * result + id; + result = prime * result + ((name == null) ? 0 : name.hashCode()); + result = prime * result + ((params == null) ? 0 : params.hashCode()); + result = prime * result + ((type == null) ? 
0 : type.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (!super.equals(obj)) return false; + Feature other = (Feature) obj; + if (id != other.id) return false; + if (name == null) { + if (other.name != null) return false; + } else if (!name.equals(other.name)) return false; + if (params == null) { + if (other.params != null) return false; + } else if (!params.equals(other.params)) return false; + if (type == null) { + if (other.type != null) return false; + } else if (!type.equals(other.type)) return false; + return true; + } + + /** + * @return the type + */ + public String getType() { + return type; + } + + /** + * @return the name + */ + public String getName() { + return name; + } + + /** + * @return the norm + */ + public Normalizer getNorm() { + return norm; + } + + /** + * @return the id + */ + public int getId() { + return id; + } + + /** + * @return the params + */ + public NamedParams getParams() { + return params; + } + + public void setNorm(Normalizer norm) { + this.norm = norm; + + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/FeatureScorer.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/FeatureScorer.java new file mode 100644 index 000000000000..1021967e2fce --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/FeatureScorer.java @@ -0,0 +1,70 @@ +package org.apache.solr.ltr.ranking; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.io.IOException; +import java.util.HashMap; + +import org.apache.lucene.search.Scorer; + +/** + * A Scorer that computes the value of a single feature for a document + */ +public abstract class FeatureScorer extends Scorer { + + protected String name; + private HashMap<String,Object> docInfo; + + public FeatureScorer(FeatureWeight weight) { + super(weight); + this.name = weight.getName(); + } + + @Override + public abstract float score() throws IOException; + + /** + * Used in the FeatureWeight's explain. Each feature should implement this + * returning properties of the specific scorer useful for an explain. For + * example "MyCustomClassFeature [name=" + name + "myVariable:" + myVariable + + * "]"; + */ + @Override + public abstract String toString(); + + /** + * Used to provide context from initial score steps to later reranking steps. 
+ */ + public void setDocInfo(HashMap<String,Object> iDocInfo) { + docInfo = iDocInfo; + } + + public Object getDocParam(String key) { + return docInfo == null ? null : docInfo.get(key); + } + + public boolean hasDocParam(String key) { + if (docInfo != null) return docInfo.containsKey(key); + else return false; + } + + @Override + public int freq() throws IOException { + throw new UnsupportedOperationException(); + } +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/FeatureWeight.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/FeatureWeight.java new file mode 100644 index 000000000000..fa816586ee09 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/FeatureWeight.java @@ -0,0 +1,159 @@ +package org.apache.solr.ltr.ranking; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import java.io.IOException; +import java.util.Map; +import java.util.Set; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.Weight; +import org.apache.solr.ltr.feature.norm.Normalizer; +import org.apache.solr.ltr.feature.norm.impl.IdentityNormalizer; +import org.apache.solr.ltr.util.MacroExpander; +import org.apache.solr.ltr.util.NamedParams; +import org.apache.solr.request.SolrQueryRequest; + +public abstract class FeatureWeight extends Weight { + + protected String name; + protected NamedParams params = NamedParams.EMPTY; + protected Normalizer norm = IdentityNormalizer.INSTANCE; + protected IndexSearcher searcher; + protected SolrQueryRequest request; + protected Map efi; + protected MacroExpander macroExpander; + protected Query originalQuery; + protected int id; + + /** + * Initialize a feature without the normalizer from the feature file. This is + * called on initial construction since multiple models share the same + * features, but have different normalizers. A concrete model's feature is + * copied through featForNewModel(). + * + * @param q + * Solr query associated with this FeatureWeight + * @param searcher + * Solr searcher available for features if they need them + * @param name + * Name of the feature + * @param params + * Custom parameters that the feature may use + * @param norm + * Feature normalizer used to normalize the feature value + * @param id + * Unique ID for this feature. Similar to feature name, except it can + * be used to directly access the feature in the global list of + * features. 
+ */ + public FeatureWeight(Query q, IndexSearcher searcher, String name, + NamedParams params, Normalizer norm, int id) { + super(q); + this.searcher = searcher; + this.name = name; + this.params = params; + this.id = id; + this.norm = norm; + } + + public final void setRequest(SolrQueryRequest request) { + this.request = request; + } + + public final void setExternalFeatureInfo(Map efi) { + this.efi = efi; + this.macroExpander = new MacroExpander(efi); + } + + /** + * Called once after all parameters have been set on the weight. Override this + * to do things with the original query, request, or external parameters. + */ + public void process() throws IOException {} + + public String getName() { + return name; + } + + public Normalizer getNorm() { + return norm; + } + + public NamedParams getParams() { + return params; + } + + public int getId() { + return id; + } + + public float getDefaultValue() { + return 0; + } + + @Override + public abstract FeatureScorer scorer(LeafReaderContext context) + throws IOException; + + @Override + public Explanation explain(LeafReaderContext context, int doc) + throws IOException { + FeatureScorer r = scorer(context); + r.iterator().advance(doc); + float score = getDefaultValue(); + if (r.docID() == doc) score = r.score(); + + return Explanation.match(score, r.toString()); + } + + @Override + public float getValueForNormalization() throws IOException { + return 1f; + } + + @Override + public void normalize(float norm, float topLevelBoost) { + // For advanced features that use Solr weights internally, you must override + // and pass this call on to them + } + + @Override + public void extractTerms(Set terms) { + // needs to be implemented by query subclasses + throw new UnsupportedOperationException(); + } + + @Override + public String toString() { + return this.getClass().getName() + " [name=" + name + ", params=" + params + + "]"; + } + + /** + * @param originalQuery + * the originalQuery to set + */ + public final void 
setOriginalQuery(Query originalQuery) { + this.originalQuery = originalQuery; + } +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/LTRCollector.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/LTRCollector.java new file mode 100644 index 000000000000..836692bae66e --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/LTRCollector.java @@ -0,0 +1,183 @@ +package org.apache.solr.ltr.ranking; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import java.io.IOException; +import java.lang.invoke.MethodHandles; +import java.util.Arrays; +import java.util.Comparator; +import java.util.Map; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.LeafCollector; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TopDocsCollector; +import org.apache.lucene.search.TopFieldCollector; +import org.apache.lucene.search.TopScoreDocCollector; +import org.apache.lucene.util.BytesRef; +import org.apache.solr.common.SolrException; +import org.apache.solr.handler.component.QueryElevationComponent; +import org.apache.solr.request.SolrRequestInfo; +import org.apache.solr.search.QueryCommand; +import org.apache.solr.search.SolrIndexSearcher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.carrotsearch.hppc.IntFloatHashMap; +import com.carrotsearch.hppc.IntIntHashMap; + +@SuppressWarnings("rawtypes") +public class LTRCollector extends TopDocsCollector { + // FIXME: This should extend ReRankCollector since it is mostly a copy + + private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + + private ModelQuery reRankModel; + private TopDocsCollector mainCollector; + private IndexSearcher searcher; + private int reRankDocs; + private Map boostedPriority; + + @SuppressWarnings("unchecked") + public LTRCollector(int reRankDocs, ModelQuery reRankModel, QueryCommand cmd, + IndexSearcher searcher, Map boostedPriority) + throws IOException { + super(null); + this.reRankModel = reRankModel; + this.reRankDocs = reRankDocs; + this.boostedPriority = boostedPriority; + Sort sort = cmd.getSort(); + if (sort == null) { + this.mainCollector = TopScoreDocCollector.create(this.reRankDocs); + } else { + sort = sort.rewrite(searcher); + this.mainCollector = TopFieldCollector.create(sort, 
this.reRankDocs, + false, true, true); + } + this.searcher = searcher; + } + + @Override + public LeafCollector getLeafCollector(LeafReaderContext context) + throws IOException { + return mainCollector.getLeafCollector(context); + } + + @Override + public boolean needsScores() { + return true; + } + + @Override + protected int topDocsSize() { + return reRankDocs; + } + + @Override + public int getTotalHits() { + return mainCollector.getTotalHits(); + } + + @SuppressWarnings("unchecked") + @Override + public TopDocs topDocs(int start, int howMany) { + try { + if (howMany > reRankDocs) { + howMany = reRankDocs; + } + + TopDocs mainDocs = mainCollector.topDocs(0, reRankDocs); + TopDocs topRerankDocs; + try { + topRerankDocs = new LTRRescorer(reRankModel).rescore(searcher, + mainDocs, howMany); + } catch (IOException e) { + logger.error("LTRRescorer reranking failed. " + e); + e.printStackTrace(); + // If someone deployed a messed up model, we don't want to crash and + // burn. + // Return the original list at least + topRerankDocs = mainDocs; + } + + if (boostedPriority != null) { + SolrRequestInfo info = SolrRequestInfo.getRequestInfo(); + Map requestContext = null; + if (info != null) { + requestContext = info.getReq().getContext(); + } + + IntIntHashMap boostedDocs = QueryElevationComponent.getBoostDocs( + (SolrIndexSearcher) searcher, boostedPriority, requestContext); + + Arrays.sort(topRerankDocs.scoreDocs, new BoostedComp(boostedDocs, + mainDocs.scoreDocs, topRerankDocs.getMaxScore())); + + ScoreDoc[] scoreDocs = new ScoreDoc[howMany]; + System.arraycopy(topRerankDocs.scoreDocs, 0, scoreDocs, 0, howMany); + topRerankDocs.scoreDocs = scoreDocs; + } + + return topRerankDocs; + + } catch (Exception e) { + logger.error("Exception: ",e); + throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e); + } + } + + public class BoostedComp implements Comparator { + IntFloatHashMap boostedMap; + + public BoostedComp(IntIntHashMap boostedDocs, ScoreDoc[] scoreDocs, + 
float maxScore) { + this.boostedMap = new IntFloatHashMap(boostedDocs.size() * 2); + + for (int i = 0; i < scoreDocs.length; i++) { + final int idx; + if ((idx = boostedDocs.indexOf(scoreDocs[i].doc)) >= 0) { + boostedMap + .put(scoreDocs[i].doc, maxScore + boostedDocs.indexGet(idx)); + } else { + break; + } + } + } + + public int compare(Object o1, Object o2) { + ScoreDoc doc1 = (ScoreDoc) o1; + ScoreDoc doc2 = (ScoreDoc) o2; + float score1 = doc1.score; + float score2 = doc2.score; + int idx; + if ((idx = boostedMap.indexOf(doc1.doc)) >= 0) { + score1 = boostedMap.indexGet(idx); + } + + if ((idx = boostedMap.indexOf(doc2.doc)) >= 0) { + score2 = boostedMap.indexGet(idx); + } + + return -Float.compare(score1, score2); + } + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/LTRComponent.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/LTRComponent.java new file mode 100644 index 000000000000..250ed2775dda --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/LTRComponent.java @@ -0,0 +1,92 @@ +package org.apache.solr.ltr.ranking; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import java.io.IOException; + +import org.apache.solr.common.SolrException; +import org.apache.solr.common.util.NamedList; +import org.apache.solr.core.SolrCore; +import org.apache.solr.handler.component.ResponseBuilder; +import org.apache.solr.handler.component.SearchComponent; +import org.apache.solr.ltr.rest.ManagedFeatureStore; +import org.apache.solr.ltr.rest.ManagedModelStore; +import org.apache.solr.rest.ManagedResource; +import org.apache.solr.rest.ManagedResourceObserver; +import org.apache.solr.util.plugin.SolrCoreAware; + +/** + * The LTRComponent is intended to be used for offline training of + * your model in order to fetch the feature vectors of the top matching + * documents. + */ +public class LTRComponent extends SearchComponent implements SolrCoreAware, + ManagedResourceObserver { + + // TODO: This is the Solr way, move these to LTRParams in solr.common.params + public interface LTRParams { + // Set to true to turn on feature vectors in the LTRComponent + public static final String FV = "fv"; + public static final String FV_RESPONSE_WRITER = "fvwt"; + public static final String FSTORE_END_POINT = "/schema/fstore"; + public static final String MSTORE_END_POINT = "/schema/mstore"; + + } + + public static final String LOGGER_NAME = "solr-feature-logger"; + public static final String FEATURE_PARAM = "featureVectors"; + + @SuppressWarnings("rawtypes") + @Override + public void init(NamedList args) {} + + @Override + public void prepare(ResponseBuilder rb) throws IOException {} + + @Override + public void process(ResponseBuilder rb) throws IOException {} + + @Override + public String getDescription() { + return "Manages models and features in Solr"; + } + + @Override + public void onManagedResourceInitialized(NamedList args, + ManagedResource res) throws SolrException { + // FIXME do we need this? 
+ } + + @Override + public void inform(SolrCore core) { + core.getRestManager().addManagedResource(LTRParams.FSTORE_END_POINT, + ManagedFeatureStore.class); + ManagedFeatureStore fr = (ManagedFeatureStore) core.getRestManager() + .getManagedResource(LTRParams.FSTORE_END_POINT); + core.getRestManager().addManagedResource(LTRParams.MSTORE_END_POINT, + ManagedModelStore.class); + + ManagedModelStore mr = (ManagedModelStore) core.getRestManager() + .getManagedResource(LTRParams.MSTORE_END_POINT); + // core.getResourceLoader().getManagedResourceRegistry().registerManagedResource(LTRParams.FSTORE_END_POINT, + // , observer); + mr.init(fr); + // now we can safely load the models + mr.loadStoredModels(); + } +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/LTRFeatureLoggerTransformerFactory.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/LTRFeatureLoggerTransformerFactory.java new file mode 100644 index 000000000000..dfcf082cb0e4 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/LTRFeatureLoggerTransformerFactory.java @@ -0,0 +1,147 @@ +package org.apache.solr.ltr.ranking; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.ReaderUtil; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; +import org.apache.solr.common.SolrDocument; +import org.apache.solr.common.SolrException; +import org.apache.solr.common.SolrException.ErrorCode; +import org.apache.solr.common.params.SolrParams; +import org.apache.solr.common.util.NamedList; +import org.apache.solr.ltr.log.FeatureLogger; +import org.apache.solr.ltr.ranking.LTRQParserPlugin.LTRQParser; +import org.apache.solr.ltr.ranking.ModelQuery.ModelWeight; +import org.apache.solr.request.SolrQueryRequest; +import org.apache.solr.response.ResultContext; +import org.apache.solr.response.transform.DocTransformer; +import org.apache.solr.response.transform.TransformerFactory; +import org.apache.solr.search.SolrIndexSearcher; + +/** + * This transformer will take care to generate and append in the response the + * features declared in the feature store of the current model. The class is + * useful if you are not interested in the reranking (e.g., bootstrapping a + * machine learning framework). 
+ */ +public class LTRFeatureLoggerTransformerFactory extends TransformerFactory { + + SolrQueryRequest req; + + @Override + public void init(@SuppressWarnings("rawtypes") NamedList args) { + super.init(args); + } + + @Override + public DocTransformer create(String name, SolrParams params, + SolrQueryRequest req) { + this.req = req; + return new FeatureTransformer(name); + } + + class FeatureTransformer extends DocTransformer { + + String name; + List leafContexts; + SolrIndexSearcher searcher; + ModelQuery reRankModel; + ModelWeight modelWeight; + FeatureLogger featurelLogger; + + /** + * @param name + * Name of the field to be added in a document representing the + * feature vectors + */ + public FeatureTransformer(String name) { + this.name = name; + } + + @Override + public String getName() { + return name; + } + + @Override + public void setContext(ResultContext context) { + super.setContext(context); + if (context == null) return; + if (context.getRequest() == null) return; + reRankModel = (ModelQuery) req.getContext() + .get(LTRQParser.MODEL); + if (reRankModel == null) throw new SolrException( + org.apache.solr.common.SolrException.ErrorCode.BAD_REQUEST, + "model is null"); + reRankModel.setRequest(context.getRequest()); + featurelLogger = reRankModel.getFeatureLogger(); + searcher = context.getSearcher(); + if (searcher == null) throw new SolrException( + org.apache.solr.common.SolrException.ErrorCode.BAD_REQUEST, + "searcher is null"); + leafContexts = searcher.getTopReaderContext().leaves(); + Weight w; + try { + w = reRankModel.createWeight(searcher, true); + } catch (IOException e) { + throw new SolrException(ErrorCode.BAD_REQUEST, e.getMessage(), e); + } + if (w == null || !(w instanceof ModelWeight)) { + throw new SolrException(ErrorCode.BAD_REQUEST, + "error logging the features, model weight is null"); + } + modelWeight = (ModelWeight) w; + + } + + @Override + public void transform(SolrDocument doc, int docid, float score) + throws IOException { + 
Object fv = featurelLogger.getFeatureVector(docid, reRankModel, searcher); + if (fv == null) { // FV for this document was not in the cache + int n = ReaderUtil.subIndex(docid, leafContexts); + final LeafReaderContext atomicContext = leafContexts.get(n); + int deBasedDoc = docid - atomicContext.docBase; + Scorer r = modelWeight.scorer(atomicContext); + if ((r == null || r.iterator().advance(deBasedDoc) != deBasedDoc) + && fv == null) { + doc.addField(name, featurelLogger.makeFeatureVector(new String[0], + new float[0], new boolean[0])); + } else { + float finalScore = r.score(); + String[] names = modelWeight.allFeatureNames; + float[] values = modelWeight.allFeatureValues; + boolean[] valuesUsed = modelWeight.allFeaturesUsed; + doc.addField(name, + featurelLogger.makeFeatureVector(names, values, valuesUsed)); + } + } else { + doc.addField(name, fv); + } + + } + + } + +} \ No newline at end of file diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/LTRQParserPlugin.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/LTRQParserPlugin.java new file mode 100644 index 000000000000..07d56ba152e9 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/LTRQParserPlugin.java @@ -0,0 +1,158 @@ +package org.apache.solr.ltr.ranking; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.lang.invoke.MethodHandles; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; + +import org.apache.lucene.search.Query; +import org.apache.solr.common.SolrException; +import org.apache.solr.common.SolrException.ErrorCode; +import org.apache.solr.common.params.CommonParams; +import org.apache.solr.common.params.SolrParams; +import org.apache.solr.common.util.NamedList; +import org.apache.solr.ltr.feature.ModelMetadata; +import org.apache.solr.ltr.log.FeatureLogger; +import org.apache.solr.ltr.ranking.LTRComponent.LTRParams; +import org.apache.solr.ltr.rest.ManagedModelStore; +import org.apache.solr.ltr.util.ModelException; +import org.apache.solr.request.SolrQueryRequest; +import org.apache.solr.search.QParser; +import org.apache.solr.search.QParserPlugin; +import org.apache.solr.search.SyntaxError; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Plug into solr a rerank model. 
+ * + * Learning to Rank Query Parser Syntax: rq={!ltr model=6029760550880411648 + * reRankDocs=300 efi.myCompanyQueryIntent=0.98} + * + */ +public class LTRQParserPlugin extends QParserPlugin { + public static final String NAME = "ltr"; + + private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + + + @Override + public void init(@SuppressWarnings("rawtypes") NamedList args) {} + + @Override + public QParser createParser(String qstr, SolrParams localParams, + SolrParams params, SolrQueryRequest req) { + return new LTRQParser(qstr, localParams, params, req); + } + + public class LTRQParser extends QParser { + // param for setting the model + public static final String MODEL = "model"; + + // param for setting how many documents should be reranked + public static final String RERANK_DOCS = "reRankDocs"; + + // params for setting custom external info that features can use, like query + // intent + // TODO: Can we just pass the entire request all the way down to all + // models/features?
+ public static final String EXTERNAL_FEATURE_INFO = "efi."; + + ManagedModelStore mr = null; + + public LTRQParser(String qstr, SolrParams localParams, SolrParams params, + SolrQueryRequest req) { + super(qstr, localParams, params, req); + + mr = (ManagedModelStore) req.getCore().getRestManager() + .getManagedResource(LTRParams.MSTORE_END_POINT); + } + + @Override + public Query parse() throws SyntaxError { + // ReRanking Model + String modelName = localParams.get(MODEL); + if (modelName == null || modelName.isEmpty()) { + throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, + "Must provide model in the request"); + } + + ModelQuery reRankModel = null; + try { + ModelMetadata meta = mr.getModel(modelName); + reRankModel = new ModelQuery(meta); + } catch (ModelException e) { + throw new SolrException(ErrorCode.BAD_REQUEST, e); + } + + // String[] fl = req.getParams().getParams(CommonParams.FL); Contains the + // [transformer] + // ReRank doc count + // Allow reranking more docs than shown, since the nth doc might be the + // best one after reranking, + // but not showing more than is reranked. + int reRankDocs = localParams.getInt(RERANK_DOCS, 200); + int start = params.getInt(CommonParams.START, 0); + int rows = params.getInt(CommonParams.ROWS, 10); + // Feature Vectors + // FIXME: Exception if feature vectors requested without specifying what + // features to return?? + // For training a new model offline you need feature vectors, but don't yet + // have a model.
Should provide the FeatureStore name as an arg to the + // feature vector + // transformer and remove the duplicate fv=true arg + boolean returnFeatureVectors = params.getBool(LTRParams.FV, false); + + if (returnFeatureVectors) { + + FeatureLogger solrLogger = FeatureLogger.getFeatureLogger(params + .get(LTRParams.FV_RESPONSE_WRITER)); + reRankModel.setFeatureLogger(solrLogger); + req.getContext().put(LTRComponent.LOGGER_NAME, solrLogger); + req.getContext().put(MODEL, reRankModel); + } + + if (start + rows > reRankDocs) { + throw new SolrException(ErrorCode.BAD_REQUEST, + "Requesting more documents than being reranked."); + } + reRankDocs = Math.max(start + rows, reRankDocs); + + // External features + Map externalFeatureInfo = new HashMap<>(); + for (Iterator it = localParams.getParameterNamesIterator(); it + .hasNext();) { + final String name = it.next(); + if (name.startsWith(EXTERNAL_FEATURE_INFO)) { + externalFeatureInfo.put( + name.substring(EXTERNAL_FEATURE_INFO.length()), + localParams.get(name)); + } + } + reRankModel.setExternalFeatureInfo(externalFeatureInfo); + + logger.info("Reranking {} docs using model {}", reRankDocs, reRankModel + .getMetadata().getName()); + reRankModel.setRequest(req); + + return new LTRQuery(reRankModel, reRankDocs); + } + } +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/LTRQuery.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/LTRQuery.java new file mode 100644 index 000000000000..a45ea9d18581 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/LTRQuery.java @@ -0,0 +1,167 @@ +package org.apache.solr.ltr.ranking; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.io.IOException; +import java.util.Map; +import java.util.Set; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.TopDocsCollector; +import org.apache.lucene.search.Weight; +import org.apache.lucene.util.BytesRef; +import org.apache.solr.handler.component.MergeStrategy; +import org.apache.solr.handler.component.QueryElevationComponent; +import org.apache.solr.request.SolrRequestInfo; +import org.apache.solr.search.QueryCommand; +import org.apache.solr.search.RankQuery; +import org.apache.solr.search.SolrIndexSearcher; + +/** + * The LTRQuery and LTRWeight wrap the main query to fetch matching docs. It + * then provides its own TopDocsCollector, which goes through the top X docs and + * reranks them using the provided reRankModel. 
+ */ +public class LTRQuery extends RankQuery { + private Query mainQuery = new MatchAllDocsQuery(); + private ModelQuery reRankModel; + private int reRankDocs; + private Map boostedPriority; + + public LTRQuery(ModelQuery reRankModel, int reRankDocs) { + this.reRankModel = reRankModel; + this.reRankDocs = reRankDocs; + } + + @Override + public int hashCode() { + return (mainQuery.hashCode() + reRankModel.hashCode() + reRankDocs); + } + + @Override + public boolean equals(Object o) { + if (!super.equals(o)) return false; + LTRQuery rrq = (LTRQuery) o; + return (mainQuery.equals(rrq.mainQuery) + && reRankModel.equals(rrq.reRankModel) && reRankDocs == rrq.reRankDocs); + } + + @Override + public RankQuery wrap(Query _mainQuery) { + if (_mainQuery != null) { + this.mainQuery = _mainQuery; + } + + reRankModel.setOriginalQuery(mainQuery); + + return this; + } + + @Override + public MergeStrategy getMergeStrategy() { + return null; + } + + @SuppressWarnings({"rawtypes", "unchecked"}) + @Override + public TopDocsCollector getTopDocsCollector(int len, QueryCommand cmd, + IndexSearcher searcher) throws IOException { + + if (this.boostedPriority == null) { + SolrRequestInfo info = SolrRequestInfo.getRequestInfo(); + if (info != null) { + Map context = info.getReq().getContext(); + this.boostedPriority = (Map) context + .get(QueryElevationComponent.BOOSTED_PRIORITY); + // https://github.com/apache/lucene-solr/blob/5775be6e6242c0f7ec108b10ebbf9da3a7d07a4b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/TFValueSource.java#L56 + // function query needs the searcher in the context + context.put("searcher", searcher); + } + } + + return new LTRCollector(reRankDocs, reRankModel, cmd, searcher, + boostedPriority); + // return new LTRCollector(reRankDocs, reRankModel, cmd, searcher, + // boostedPriority); + } + + @Override + public String toString(String field) { + return "{!ltr mainQuery='" + mainQuery.toString() + "' reRankModel='" + + reRankModel.toString() + 
"' reRankDocs=" + reRankDocs + "}"; + } + + @Override + public Weight createWeight(IndexSearcher searcher, boolean needsScores) + throws IOException { + Weight mainWeight = mainQuery.createWeight(searcher, needsScores); + return new LTRWeight(searcher, mainWeight, reRankModel); + } + + /** + * This is the weight for the main solr query in the LTRQuery. The only thing + * this really does is have an explain using the reRankQuery. + */ + public class LTRWeight extends Weight { + private ModelQuery reRankModel; + private Weight mainWeight; + private IndexSearcher searcher; + + public LTRWeight(IndexSearcher searcher, Weight mainWeight, + ModelQuery reRankModel) throws IOException { + super(LTRQuery.this); + this.reRankModel = reRankModel; + this.mainWeight = mainWeight; + this.searcher = searcher; + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) + throws IOException { + Explanation mainExplain = mainWeight.explain(context, doc); + return new LTRRescorer(reRankModel).explain(searcher, mainExplain, + context.docBase + doc); + } + + @Override + public void extractTerms(Set terms) { + mainWeight.extractTerms(terms); + } + + @Override + public float getValueForNormalization() throws IOException { + return mainWeight.getValueForNormalization(); + } + + @Override + public void normalize(float norm, float topLevelBoost) { + mainWeight.normalize(norm, topLevelBoost); + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + return mainWeight.scorer(context); + } + } +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/LTRRescorer.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/LTRRescorer.java new file mode 100644 index 000000000000..9e6b53b9e0f1 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/LTRRescorer.java @@ -0,0 +1,265 @@ +package org.apache.solr.ltr.ranking; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor 
license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.io.IOException; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.ReaderUtil; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Rescorer; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.Weight; +import org.apache.solr.ltr.log.FeatureLogger; +import org.apache.solr.ltr.ranking.ModelQuery.ModelWeight; +import org.apache.solr.ltr.ranking.ModelQuery.ModelWeight.ModelScorer; +import org.apache.solr.search.SolrIndexSearcher; + +/** + * Implements the rescoring logic. The top documents returned by solr with their + * original scores, will be processed by a {@link ModelQuery} that will assign a + * new score to each document. The top documents will be resorted based on the + * new score. 
+ * */ +public class LTRRescorer extends Rescorer { + + ModelQuery reRankModel; + public static final String ORIGINAL_DOC_NAME = "ORIGINAL_DOC_SCORE"; + + public LTRRescorer(ModelQuery reRankModel) { + this.reRankModel = reRankModel; + } + + private void heapAdjust(ScoreDoc[] hits, int size, int root) { + ScoreDoc doc = hits[root]; + float score = doc.score; + int i = root; + while (i <= (size >> 1) - 1) { + int lchild = (i << 1) + 1; + ScoreDoc ldoc = hits[lchild]; + float lscore = ldoc.score; + float rscore = Float.MAX_VALUE; + int rchild = (i << 1) + 2; + ScoreDoc rdoc = null; + if (rchild < size) { + rdoc = hits[rchild]; + rscore = rdoc.score; + } + if (lscore < score) { + if (rscore < lscore) { + hits[i] = rdoc; + hits[rchild] = doc; + i = rchild; + } else { + hits[i] = ldoc; + hits[lchild] = doc; + i = lchild; + } + } else if (rscore < score) { + hits[i] = rdoc; + hits[rchild] = doc; + i = rchild; + } else { + return; + } + } + } + + private void heapify(ScoreDoc[] hits, int size) { + for (int i = (size >> 1) - 1; i >= 0; i--) { + heapAdjust(hits, size, i); + } + } + + /** + * rescores the documents: + * + * @param searcher + * current IndexSearcher + * @param firstPassTopDocs + * documents to rerank; + * @param topN + * documents to return; + */ + @Override + public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, + int topN) throws IOException { + if (topN == 0 || firstPassTopDocs.totalHits == 0) { + return firstPassTopDocs; + } + + ScoreDoc[] hits = firstPassTopDocs.scoreDocs; + + Arrays.sort(hits, new Comparator<ScoreDoc>() { + @Override + public int compare(ScoreDoc a, ScoreDoc b) { + return a.doc - b.doc; + } + }); + + topN = Math.min(topN, firstPassTopDocs.totalHits); + ScoreDoc[] reranked = new ScoreDoc[topN]; + String[] featureNames; + float[] featureValues; + boolean[] featuresUsed; + + List<LeafReaderContext> leaves = searcher.getIndexReader().leaves(); + + int readerUpto = -1; + int endDoc = 0; + int docBase = 0; + + ModelScorer scorer = null; + int hitUpto =
0; + + ModelWeight modelWeight = (ModelWeight) searcher.createNormalizedWeight( + reRankModel, true); + FeatureLogger featureLogger = reRankModel.getFeatureLogger(); + + // FIXME: I dislike that we have no guarantee this is actually a + // SolrIndexSearcher. + // We should do something about that + SolrIndexSearcher solrIndexSearch = (SolrIndexSearcher) searcher; + + // FIXME + // All of this heap code is only for logging. Wrap all this code in + // 1 outer if (fl != null) so we can skip heap stuff if the request doesn't + // call for a feature vector. + // + // that could be done but it would require a new vector of size $rerank, + // that in the end we would have to sort, while using the heap, also if + // we do not log, in the end we sort a smaller array of topN results (that + // is the heap array). + // The heap is just anticipating the sorting of the array, so I don't think + // it would + // save time. + + while (hitUpto < hits.length) { + ScoreDoc hit = hits[hitUpto]; + int docID = hit.doc; + + LeafReaderContext readerContext = null; + while (docID >= endDoc) { + readerUpto++; + readerContext = leaves.get(readerUpto); + endDoc = readerContext.docBase + readerContext.reader().maxDoc(); + } + + // We advanced to another segment + if (readerContext != null) { + docBase = readerContext.docBase; + scorer = modelWeight.scorer(readerContext); + } + + // Scorer for a ModelWeight should never be null since we always have to + // call score + // even if no feature scorers match, since a model might use that info to + // return a + // non-zero score. Same applies for the case of advancing a ModelScorer + // past the target + // doc since the model algorithm still needs to compute a potentially + // non-zero score from blank features.
+ assert (scorer != null); + int targetDoc = docID - docBase; + int actualDoc = scorer.docID(); + actualDoc = scorer.iterator().advance(targetDoc); + + scorer.setDocInfoParam(ORIGINAL_DOC_NAME, Float.valueOf(hit.score)); + hit.score = scorer.score(); + featureNames = modelWeight.allFeatureNames; + featureValues = modelWeight.allFeatureValues; + featuresUsed = modelWeight.allFeaturesUsed; + + if (hitUpto < topN) { + reranked[hitUpto] = hit; + // if the heap is not full, maybe I want to log the features for this + // document + if (featureLogger != null) { + featureLogger.log(hit.doc, reRankModel, solrIndexSearch, + featureNames, featureValues, featuresUsed); + } + } else if (hitUpto == topN) { + // collected topN document, I create the heap + heapify(reranked, topN); + } + if (hitUpto >= topN) { + // once that heap is ready, if the score of this document is lower than + // the minimum + // I don't want to log the feature. Otherwise I replace it with the + // minimum and fix the + // heap. + if (hit.score > reranked[0].score) { + reranked[0] = hit; + heapAdjust(reranked, topN, 0); + if (featureLogger != null) { + featureLogger.log(hit.doc, reRankModel, solrIndexSearch, + featureNames, featureValues, featuresUsed); + } + } + } + + hitUpto++; + } + + // Must sort all documents that we reranked, and then select the top N + + // ScoreDoc[] reranked = heap.getArray(); + Arrays.sort(reranked, new Comparator<ScoreDoc>() { + @Override + public int compare(ScoreDoc a, ScoreDoc b) { + // Sort by score descending, then docID ascending: + if (a.score > b.score) { + return -1; + } else if (a.score < b.score) { + return 1; + } else { + // This subtraction can't overflow int + // because docIDs are >= 0: + return a.doc - b.doc; + } + } + }); + + // if (topN < hits.length) { + // ScoreDoc[] subset = new ScoreDoc[topN]; + // System.arraycopy(hits, 0, subset, 0, topN); + // hits = subset; + // } + + return new TopDocs(firstPassTopDocs.totalHits, reranked, reranked[0].score); + } + + @Override +
public Explanation explain(IndexSearcher searcher, + Explanation firstPassExplanation, int docID) throws IOException { + + List leafContexts = searcher.getTopReaderContext() + .leaves(); + int n = ReaderUtil.subIndex(docID, leafContexts); + final LeafReaderContext context = leafContexts.get(n); + int deBasedDoc = docID - context.docBase; + Weight modelWeight = searcher.createNormalizedWeight(reRankModel, true); + return modelWeight.explain(context, deBasedDoc); + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/LambdaMARTModel.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/LambdaMARTModel.java new file mode 100644 index 000000000000..6096a035a81d --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/LambdaMARTModel.java @@ -0,0 +1,242 @@ +package org.apache.solr.ltr.ranking; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.Explanation; +import org.apache.solr.ltr.feature.ModelMetadata; +import org.apache.solr.ltr.util.ModelException; +import org.apache.solr.ltr.util.NamedParams; + +public class LambdaMARTModel extends ModelMetadata { + + List<RegressionTree> trees = new ArrayList<RegressionTree>(); + + class RegressionTreeNode { + public static final float NODE_SPLIT_SLACK = 1E-6f; + public float value; + public String feature; + public int featureIndex; + public float threshold; + public RegressionTreeNode left = null; + public RegressionTreeNode right = null; + + public boolean isLeaf() { + return feature == null; + } + + public float score(float[] featureVector) { + if (isLeaf()) { + return value; + } + + if (featureIndex < 0 || // unsupported feature + featureIndex >= featureVector.length || // tree is looking for a + // feature that does not + // exist + featureVector[featureIndex] <= threshold) { + return left.score(featureVector); + } + + return right.score(featureVector); + } + + public String explain(float[] featureVector) { + if (isLeaf()) { + return "val: " + value; + } + + String rval = ""; + + // could store extra information about how much training data supported + // each branch and report + // that here + + if (featureIndex < 0 || featureIndex >= featureVector.length) { + rval += "'" + feature + "' does not exist in FV, Go Left | "; + return rval + left.explain(featureVector); + } else if (featureVector[featureIndex] <= threshold) { + rval += "'" + feature + "':" + featureVector[featureIndex] + " <= " + + threshold + ", Go Left | "; + return rval + left.explain(featureVector); + } + + rval += "'" + feature + "':" + featureVector[featureIndex] + " > " + + threshold + ", Go Right | "; + + return rval + right.explain(featureVector); + } + + public RegressionTreeNode(Map map, +
HashMap fname2index) throws ModelException { + if (map.containsKey("value")) { + value = NamedParams.convertToFloat(map.get("value")); + } else { + + Object of = map.get("feature"); + if (null == of) { + throw new ModelException( + "LambdaMARTModel tree node is missing feature"); + } + + feature = (String) of; + Integer idx = fname2index.get(feature); + // this happens if the tree specifies a feature that does not exist + // this could be due to lambdaSmart building off of pre-existing trees + // that use a feature that is no longer output during feature extraction + // TODO: make lambdaSmart (in rank_svm_final repo ) + // either remove trees that depend on such features + // or prune them back above the split on that feature + featureIndex = (idx == null) ? -1 : idx; + + Object ot = map.get("threshold"); + if (null == ot) { + throw new ModelException( + "LambdaMARTModel tree node is missing threshold"); + } + + threshold = NamedParams.convertToFloat(ot) + NODE_SPLIT_SLACK; + + Object ol = map.get("left"); + if (null == ol) { + throw new ModelException("LambdaMARTModel tree node is missing left"); + } + + left = new RegressionTreeNode((Map) ol, fname2index); + + Object or = map.get("right"); + if (null == or) { + throw new ModelException("LambdaMARTModel tree node is missing right"); + } + + right = new RegressionTreeNode((Map) or, fname2index); + } + } + + } + + class RegressionTree { + public float weight; + public RegressionTreeNode root; + + public float score(float[] featureVector) { + return weight * root.score(featureVector); + } + + public String explain(float[] featureVector) { + return root.explain(featureVector); + } + + public RegressionTree(Map map, + HashMap fname2index) throws ModelException { + Object ow = map.get("weight"); + if (null == ow) { + throw new ModelException( + "LambdaMARTModel tree doesn't contain a weight"); + } + + weight = NamedParams.convertToFloat(ow); + + Object ot = map.get("tree"); + + if (null == ot) { + throw new 
ModelException("LambdaMARTModel tree doesn't contain a tree"); + } + + root = new RegressionTreeNode((Map) ot, fname2index); + } + } + + public LambdaMARTModel(String name, String type, List features, + String featureStoreName, Collection allFeatures, + NamedParams params) throws ModelException { + super(name, type, features, featureStoreName, allFeatures, params); + + if (!hasParams()) { + throw new ModelException("LambdaMARTModel doesn't contain any params"); + } + + HashMap fname2index = new HashMap(); + for (int i = 0; i < features.size(); ++i) { + String key = features.get(i).getName(); + fname2index.put(key, i); + } + + List jsonTrees = getParams().getList("trees"); + + if (jsonTrees == null || jsonTrees.isEmpty()) { + throw new ModelException("LambdaMARTModel doesn't contain any trees"); + } + + for (Object o : jsonTrees) { + Map t = (Map) o; + RegressionTree rt = new RegressionTree(t, fname2index); + trees.add(rt); + } + + } + + @Override + public float score(float[] modelFeatureValuesNormalized) { + float score = 0; + for (RegressionTree t : trees) { + score += t.score(modelFeatureValuesNormalized); + } + return score; + } + + // ///////////////////////////////////////// + // produces a string that looks like: + // 40.0 = lambdamartmodel [ org.apache.solr.ltr.ranking.LambdaMARTModel ] + // model applied to + // features, sum of: + // 50.0 = tree 0 | 'matchedTitle':1.0 > 0.500001, Go Right | + // 'this_feature_doesnt_exist' does not + // exist in FV, Go Left | val: 50.0 + // -10.0 = tree 1 | val: -10.0 + public Explanation explain(LeafReaderContext context, int doc, + float finalScore, List featureExplanations) { + // FIXME this still needs lots of work + float[] fv = new float[featureExplanations.size()]; + int index = 0; + for (Explanation featureExplain : featureExplanations) { + fv[index] = featureExplain.getValue(); + index++; + } + + List details = new ArrayList<>(); + index = 0; + + for (RegressionTree t : trees) { + float score = t.score(fv); + 
Explanation p = Explanation.match(score, + "tree " + index + " | " + t.explain(fv)); + details.add(p); + index++; + } + + return Explanation.match(finalScore, getName() + " [ " + getType() + + " ] model applied to features, sum of:", details); + } +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/ModelQuery.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/ModelQuery.java new file mode 100644 index 000000000000..f54d728ccb58 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/ModelQuery.java @@ -0,0 +1,540 @@ +package org.apache.solr.ltr.ranking; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.DisiPriorityQueue; +import org.apache.lucene.search.DisiWrapper; +import org.apache.lucene.search.DisjunctionDISIApproximation; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; +import org.apache.lucene.search.Scorer.ChildScorer; +import org.apache.solr.ltr.feature.ModelMetadata; +import org.apache.solr.ltr.feature.norm.Normalizer; +import org.apache.solr.ltr.feature.norm.impl.IdentityNormalizer; +import org.apache.solr.ltr.log.FeatureLogger; +import org.apache.solr.request.SolrQueryRequest; + +/** + * The ranking query that is run, reranking results using the ModelMetadata + * algorithm + */ +public class ModelQuery extends Query { + + // contains a description of the model + protected ModelMetadata meta; + // feature logger to output the features. 
+ private FeatureLogger fl = null; + // Map of external parameters, such as query intent, that can be used by + // features + protected Map efi; + // Original solr query used to fetch matching documents + protected Query originalQuery; + // Original solr request + protected SolrQueryRequest request; + + public ModelQuery(ModelMetadata meta) { + this.meta = meta; + } + + public ModelMetadata getMetadata() { + return meta; + } + + public void setFeatureLogger(FeatureLogger fl) { + this.fl = fl; + } + + public FeatureLogger getFeatureLogger() { + return this.fl; + } + + public Collection getAllFeatures() { + return meta.getAllFeatures(); + } + + public void setOriginalQuery(Query mainQuery) { + this.originalQuery = mainQuery; + } + + public void setExternalFeatureInfo(Map externalFeatureInfo) { + this.efi = externalFeatureInfo; + } + + public void setRequest(SolrQueryRequest request) { + this.request = request; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = super.hashCode(); + result = prime * result + ((meta == null) ? 0 : meta.hashCode()); + result = prime * result + + ((originalQuery == null) ? 0 : originalQuery.hashCode()); + result = prime * result + ((efi == null) ? 
0 : efi.hashCode()); + result = prime * result + this.toString().hashCode(); + return result; + } + + @Override + public boolean equals(Object obj) { + if (!super.equals(obj)) return false; + ModelQuery other = (ModelQuery) obj; + if (meta == null) { + if (other.meta != null) return false; + } else if (!meta.equals(other.meta)) return false; + if (originalQuery == null) { + if (other.originalQuery != null) return false; + } else if (!originalQuery.equals(other.originalQuery)) return false; + // efi contributes to hashCode(), so it must be compared here as well + if (efi == null) { + if (other.efi != null) return false; + } else if (!efi.equals(other.efi)) return false; + return true; + } + + public SolrQueryRequest getRequest() { + return request; + } + + public List getFeatures() { + return meta.getFeatures(); + } + + @Override + public ModelWeight createWeight(IndexSearcher searcher, boolean needsScores) + throws IOException { + Collection features = this.getAllFeatures(); + List modelFeatures = this.getFeatures(); + + return new ModelWeight(searcher, getWeights(modelFeatures, searcher, + needsScores), getWeights(features, searcher, needsScores)); + } + + private FeatureWeight[] getWeights(Collection features, + IndexSearcher searcher, boolean needsScores) throws IOException { + FeatureWeight[] arr = new FeatureWeight[features.size()]; + int i = 0; + SolrQueryRequest req = this.getRequest(); + // since the feature store is a LinkedHashMap, order is preserved + for (Feature f : features) { + FeatureWeight fw = f.createWeight(searcher, needsScores); + fw.setRequest(req); + fw.setOriginalQuery(originalQuery); + fw.setExternalFeatureInfo(efi); + fw.process(); + arr[i] = fw; + ++i; + } + return arr; + } + + @Override + public String toString(String field) { + return field; + } + + public class ModelWeight extends Weight { + + IndexSearcher searcher; + + // List of the model's features used for scoring. This is a subset of the + // features used for logging. 
+ FeatureWeight[] modelFeatures; + float[] modelFeatureValuesNormalized; + + // List of all the feature values, used for both scoring and logging + FeatureWeight[] allFeatureWeights; + float[] allFeatureValues; + String[] allFeatureNames; + boolean[] allFeaturesUsed; + + public ModelWeight(IndexSearcher searcher, FeatureWeight[] modelFeatures, + FeatureWeight[] allFeatures) { + super(ModelQuery.this); + this.searcher = searcher; + this.allFeatureWeights = allFeatures; + this.modelFeatures = modelFeatures; + this.modelFeatureValuesNormalized = new float[modelFeatures.length]; + this.allFeatureValues = new float[allFeatures.length]; + this.allFeatureNames = new String[allFeatures.length]; + this.allFeaturesUsed = new boolean[allFeatures.length]; + + for (int i = 0; i < allFeatures.length; ++i) { + allFeatureNames[i] = allFeatures[i].getName(); + } + } + + /** + * Goes through all the stored feature values, and calculates the normalized + * values for all the features that will be used for scoring. 
+ */ + public void normalize() { + int pos = 0; + for (FeatureWeight feature : modelFeatures) { + int featureId = feature.getId(); + if (allFeaturesUsed[featureId]) { + Normalizer norm = feature.getNorm(); + modelFeatureValuesNormalized[pos] = norm + .normalize(allFeatureValues[featureId]); + } else { + modelFeatureValuesNormalized[pos] = feature.getDefaultValue(); + } + pos++; + } + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) + throws IOException { + // FIXME: This explain doesn't skip null scorers like the scorer() + // function + Explanation[] explanations = new Explanation[allFeatureValues.length]; + int index = 0; + for (FeatureWeight feature : allFeatureWeights) { + explanations[index++] = feature.explain(context, doc); + } + + List featureExplanations = new ArrayList<>(); + for (FeatureWeight f : modelFeatures) { + Normalizer n = f.getNorm(); + Explanation e = explanations[f.id]; + if (n != IdentityNormalizer.INSTANCE) e = n.explain(e); + featureExplanations.add(e); + } + // TODO this calls the scorers twice, could be optimized. 
+ ModelScorer bs = scorer(context); + bs.iterator().advance(doc); + + float finalScore = bs.score(); + + return meta.explain(context, doc, finalScore, featureExplanations); + + } + + @Override + public float getValueForNormalization() throws IOException { + return 1; + } + + @Override + public void normalize(float norm, float topLevelBoost) { + for (FeatureWeight feature : allFeatureWeights) { + feature.normalize(norm, topLevelBoost); + } + } + + @Override + public void extractTerms(Set terms) { + for (FeatureWeight feature : allFeatureWeights) { + feature.extractTerms(terms); + } + } + + protected void reset() { + for (int i = 0, len = allFeaturesUsed.length; i < len; i++) { + allFeaturesUsed[i] = false; + } + } + + @Override + public ModelScorer scorer(LeafReaderContext context) throws IOException { + List featureScorers = new ArrayList( + allFeatureWeights.length); + for (int i = 0; i < allFeatureWeights.length; i++) { + FeatureScorer scorer = allFeatureWeights[i].scorer(context); + if (scorer != null) { + // reuse the scorer created above instead of building a second one + featureScorers.add(scorer); + } + } + + // Always return a ModelScorer, even if no features match, because we + // always need to call score on the model for every document: zero + // matching features could still yield a non-zero score for a given + // model. 
+ return new ModelScorer(this, featureScorers); + } + + public class ModelScorer extends Scorer { + protected HashMap docInfo; + protected Scorer featureTraversalScorer; + + public ModelScorer(Weight weight, List featureScorers) { + super(weight); + docInfo = new HashMap(); + for (FeatureScorer subScorer : featureScorers) { + subScorer.setDocInfo(docInfo); + } + + if (featureScorers.size() <= 1) { // TODO: Allow the use of dense + // features in other cases + featureTraversalScorer = new DenseModelScorer(weight, featureScorers); + } else { + featureTraversalScorer = new SparseModelScorer(weight, featureScorers); + } + } + + @Override + public Collection getChildren() { + return featureTraversalScorer.getChildren(); + } + + public void setDocInfoParam(String key, Object value) { + docInfo.put(key, value); + } + + @Override + public int docID() { + return featureTraversalScorer.docID(); + } + + @Override + public float score() throws IOException { + return featureTraversalScorer.score(); + } + + @Override + public int freq() throws IOException { + return featureTraversalScorer.freq(); + } + + @Override + public DocIdSetIterator iterator() { + return featureTraversalScorer.iterator(); + } + + public class SparseModelScorer extends Scorer { + protected DisiPriorityQueue subScorers; + protected ModelQuerySparseIterator itr; + + protected int targetDoc = -1; + protected int activeDoc = -1; + + protected SparseModelScorer(Weight weight, + List featureScorers) { + super(weight); + if (featureScorers.size() <= 1) { + throw new IllegalArgumentException( + "There must be at least 2 subScorers"); + } + this.subScorers = new DisiPriorityQueue(featureScorers.size()); + for (Scorer scorer : featureScorers) { + final DisiWrapper w = new DisiWrapper(scorer); + this.subScorers.add(w); + } + + itr = new ModelQuerySparseIterator(this.subScorers); + } + + @Override + public int docID() { + return itr.docID(); + } + + @Override + public float score() throws IOException { + DisiWrapper 
topList = subScorers.topList(); + // If target doc we wanted to advance to matches the actual doc + // the underlying features advanced to, perform the feature + // calculations, + // otherwise just continue with the model's scoring process with empty + // features. + reset(); + if (activeDoc == targetDoc) { + for (DisiWrapper w = topList; w != null; w = w.next) { + Scorer subScorer = w.scorer; + int featureId = ((FeatureWeight) subScorer.getWeight()).getId(); + allFeaturesUsed[featureId] = true; + allFeatureValues[featureId] = subScorer.score(); + } + } + normalize(); + return meta.score(modelFeatureValuesNormalized); + } + + @Override + public int freq() throws IOException { + DisiWrapper subMatches = subScorers.topList(); + int freq = 1; + for (DisiWrapper w = subMatches.next; w != null; w = w.next) { + freq += 1; + } + return freq; + } + + @Override + public DocIdSetIterator iterator() { + return itr; + } + + @Override + public final Collection getChildren() { + ArrayList children = new ArrayList<>(); + for (DisiWrapper scorer : subScorers) { + children.add(new ChildScorer(scorer.scorer, "SHOULD")); + } + return children; + } + + protected class ModelQuerySparseIterator extends + DisjunctionDISIApproximation { + + public ModelQuerySparseIterator(DisiPriorityQueue subIterators) { + super(subIterators); + } + + @Override + public final int nextDoc() throws IOException { + if (activeDoc == targetDoc) { + activeDoc = super.nextDoc(); + } else if (activeDoc < targetDoc) { + activeDoc = super.advance(targetDoc + 1); + } + return ++targetDoc; + } + + @Override + public final int advance(int target) throws IOException { + // If target doc we wanted to advance to matches the actual doc + // the underlying features advanced to, perform the feature + // calculations, + // otherwise just continue with the model's scoring process with + // empty features. 
+ if (activeDoc < target) activeDoc = super.advance(target); + targetDoc = target; + return targetDoc; + } + } + + } + + public class DenseModelScorer extends Scorer { + int activeDoc = -1; // The doc that our scorer's are actually at + int targetDoc = -1; // The doc we were most recently told to go to + int freq = -1; + List featureScorers; + + protected DenseModelScorer(Weight weight, + List featureScorers) { + super(weight); + this.featureScorers = featureScorers; + } + + @Override + public int docID() { + return targetDoc; + } + + @Override + public float score() throws IOException { + reset(); + freq = 0; + if (targetDoc == activeDoc) { + for (Scorer scorer : featureScorers) { + if (scorer.docID() == activeDoc) { + freq++; + int featureId = ((FeatureWeight) scorer.getWeight()).getId(); + allFeaturesUsed[featureId] = true; + allFeatureValues[featureId] = scorer.score(); + } + } + } + normalize(); + return meta.score(modelFeatureValuesNormalized); + } + + @Override + public final Collection getChildren() { + ArrayList children = new ArrayList<>(); + for (Scorer scorer : featureScorers) { + children.add(new ChildScorer(scorer, "SHOULD")); + } + return children; + } + + @Override + public int freq() throws IOException { + return freq; + } + + @Override + public DocIdSetIterator iterator() { + return new DenseIterator(); + } + + class DenseIterator extends DocIdSetIterator { + + @Override + public int docID() { + return targetDoc; + } + + @Override + public int nextDoc() throws IOException { + if (activeDoc <= targetDoc) { + activeDoc = NO_MORE_DOCS; + for (Scorer scorer : featureScorers) { + if (scorer.docID() != NO_MORE_DOCS) { + activeDoc = Math.min(activeDoc, scorer.iterator().nextDoc()); + } + } + } + return ++targetDoc; + } + + @Override + public int advance(int target) throws IOException { + if (activeDoc < target) { + activeDoc = NO_MORE_DOCS; + for (Scorer scorer : featureScorers) { + if (scorer.docID() != NO_MORE_DOCS) { + activeDoc = Math.min(activeDoc, 
+ scorer.iterator().advance(target)); + } + } + } + targetDoc = target; + return target; + } + + @Override + public long cost() { + long sum = 0; + for (int i = 0; i < featureScorers.size(); i++) { + sum += featureScorers.get(i).iterator().cost(); + } + return sum; + } + + } + } + } + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/RankSVMModel.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/RankSVMModel.java new file mode 100644 index 000000000000..4af762d6afd3 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/RankSVMModel.java @@ -0,0 +1,90 @@ +package org.apache.solr.ltr.ranking; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.Explanation; +import org.apache.solr.ltr.feature.ModelMetadata; +import org.apache.solr.ltr.util.ModelException; +import org.apache.solr.ltr.util.NamedParams; + +public class RankSVMModel extends ModelMetadata { + + protected float[] featureToWeight; + + public RankSVMModel(String name, String type, List features, + String featureStoreName, Collection allFeatures, + NamedParams params) throws ModelException { + super(name, type, features, featureStoreName, allFeatures, params); + + if (!hasParams()) { + throw new ModelException("Model " + name + " doesn't contain any weights"); + } + + Map modelWeights = (Map) getParams().get( + "weights"); + if (modelWeights == null || modelWeights.isEmpty()) { + throw new ModelException("Model " + name + " doesn't contain any weights"); + } + + // List features = getFeatures(); // model features + this.featureToWeight = new float[features.size()]; + + for (int i = 0; i < features.size(); ++i) { + String key = features.get(i).getName(); + if (!modelWeights.containsKey(key)) { + throw new ModelException("no weight for feature " + key); + } + featureToWeight[i] = modelWeights.get(key).floatValue(); + } + } + + @Override + public float score(float[] modelFeatureValuesNormalized) { + float score = 0; + for (int i = 0; i < modelFeatureValuesNormalized.length; ++i) { + score += modelFeatureValuesNormalized[i] * featureToWeight[i]; + } + return score; + } + + public Explanation explain(LeafReaderContext context, int doc, + float finalScore, List featureExplanations) { + List details = new ArrayList<>(); + int index = 0; + + for (Explanation featureExplain : featureExplanations) { + List featureDetails = new ArrayList<>(); + featureDetails.add(Explanation.match(featureToWeight[index], + "weight on feature [would be cool to have the name 
:)]")); + featureDetails.add(featureExplain); + + details.add(Explanation.match(featureExplain.getValue() + * featureToWeight[index], "prod of:", featureDetails)); + index++; + } + + return Explanation.match(finalScore, getName() + " [ " + getType() + + " ] model applied to features, sum of:", details); + } +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/package-info.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/package-info.java new file mode 100644 index 000000000000..32d1f22ec4cf --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/ranking/package-info.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + *

+ * This package contains the main logic for reranking results using + * an LTR model. + *

+ *

+ * An LTR model is plugged into the ranking through the {@link org.apache.solr.ltr.ranking.LTRQParserPlugin}, + * a {@link org.apache.solr.search.QParserPlugin}. The plugin reads + * from the request the model (instance of {@link org.apache.solr.ltr.ranking.ModelQuery}) + * to use for the request, plus any other + * parameters. It then generates an {@link org.apache.solr.ltr.ranking.LTRQuery}, + * a {@link org.apache.solr.search.RankQuery} + * that encapsulates the given model and uses it to + * rescore and rerank the documents (by using an {@link org.apache.solr.ltr.ranking.LTRCollector}). + *

+ *

+ * A model is applied to each document through a {@link org.apache.solr.ltr.ranking.ModelQuery}, a + * subclass of {@link org.apache.lucene.search.Query}. Like a normal query, + * the learned model produces a new score + * for each reranked document. + *

+ *

+ * A {@link org.apache.solr.ltr.ranking.ModelQuery} is created by providing an instance of + * {@link org.apache.solr.ltr.feature.ModelMetadata}. An instance of + * {@link org.apache.solr.ltr.feature.ModelMetadata} + * defines how to combine the features in order to create a new + * score for a document. A new learning to rank model is plugged + * into the framework by extending {@link org.apache.solr.ltr.feature.ModelMetadata} + * (see for example {@link org.apache.solr.ltr.ranking.LambdaMARTModel} and {@link org.apache.solr.ltr.ranking.RankSVMModel}). + *

+ *

+ * The {@link org.apache.solr.ltr.ranking.ModelQuery} computes the values of + * all the features (see {@link org.apache.solr.ltr.ranking.Feature}) and then delegates the final score + * generation to the {@link org.apache.solr.ltr.feature.ModelMetadata}, by calling the method + * {@link org.apache.solr.ltr.feature.ModelMetadata#score(float[] modelFeatureValuesNormalized) score(float[] modelFeatureValuesNormalized)}. + *

+ *

+ * Finally, a {@link org.apache.solr.ltr.ranking.Feature} will produce a particular value for each document, so + * it is modeled as a {@link org.apache.lucene.search.Query}. The package org.apache.solr.ltr.feature.impl contains several examples + * of features. One benefit of extending the Query object is that we can reuse + * Query as a feature, see for example {@link org.apache.solr.ltr.feature.impl.SolrFeature}. + * + * + * + * + * + */ +package org.apache.solr.ltr.ranking; diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/rest/ManagedFeatureStore.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/rest/ManagedFeatureStore.java new file mode 100644 index 000000000000..7a634132ed69 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/rest/ManagedFeatureStore.java @@ -0,0 +1,223 @@ +package org.apache.solr.ltr.rest; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import java.lang.invoke.MethodHandles; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.solr.common.SolrException; +import org.apache.solr.common.SolrException.ErrorCode; +import org.apache.solr.common.util.NamedList; +import org.apache.solr.core.SolrResourceLoader; +import org.apache.solr.ltr.feature.FeatureStore; +import org.apache.solr.ltr.ranking.Feature; +import org.apache.solr.ltr.util.FeatureException; +import org.apache.solr.ltr.util.InvalidFeatureNameException; +import org.apache.solr.ltr.util.NameValidator; +import org.apache.solr.ltr.util.NamedParams; +import org.apache.solr.response.SolrQueryResponse; +import org.apache.solr.rest.BaseSolrResource; +import org.apache.solr.rest.ManagedResource; +import org.apache.solr.rest.ManagedResourceStorage.StorageIO; +import org.apache.solr.rest.RestManager; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Managed resource for storing features. 
+ */ +public class ManagedFeatureStore extends ManagedResource implements + ManagedResource.ChildResourceSupport { + + private Map stores = new HashMap<>(); + + private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + + public static final String FEATURES_JSON_FIELD = "features"; + public static final String FEATURE_STORE_JSON_FIELD = "featureStores"; + public static final String DEFAULT_FSTORE = "_DEFAULT_"; + + public ManagedFeatureStore(String resourceId, SolrResourceLoader loader, + StorageIO storageIO) throws SolrException { + super(resourceId, loader, storageIO); + + } + + public synchronized FeatureStore getFeatureStore(String name) { + if (name == null) { + name = DEFAULT_FSTORE; + } + if (!stores.containsKey(name)) { + stores.put(name, new FeatureStore(name)); + } + return stores.get(name); + } + + @Override + protected void onManagedDataLoadedFromStorage(NamedList managedInitArgs, + Object managedData) throws SolrException { + + stores.clear(); + logger.info("------ managed feature ~ loading ------"); + if (managedData instanceof List) { + @SuppressWarnings("unchecked") + List> up = (List>) managedData; + for (Map u : up) { + update(u); + } + } + } + + public void update(Map map) { + String name = (String) map.get("name"); + String type = (String) map.get("type"); + String store = (String) map.get("store"); + + NamedParams params = null; + + if (map.containsKey("params")) { + @SuppressWarnings("unchecked") + Map np = (Map) map.get("params"); + params = new NamedParams(np); + } + + try { + + addFeature(name, type, store, params); + } catch (InvalidFeatureNameException e) { + throw new SolrException(ErrorCode.BAD_REQUEST, e); + } catch (FeatureException e) { + logger.error(e.getMessage()); + e.printStackTrace(); + throw new SolrException(ErrorCode.BAD_REQUEST, e); + } + } + + public synchronized void addFeature(String name, String type, + String featureStore, NamedParams params) + throws 
InvalidFeatureNameException, FeatureException { + if (featureStore == null) { + featureStore = DEFAULT_FSTORE; + } + + logger.info("register feature {} -> {} in store [{}]", + name, type, featureStore); + if (!NameValidator.check(name)) { + throw new InvalidFeatureNameException(name); + } + + FeatureStore fstore = getFeatureStore(featureStore); + + if (fstore.containsFeature(name)) { + logger.error( + "feature {} already contained in the store, please use a different name", + name); + throw new InvalidFeatureNameException(name + + " already contained in the store"); + } + + if (params == null) { + params = NamedParams.EMPTY; + } + + Feature feature = createFeature(name, type, params, fstore.size()); + + fstore.add(feature); + } + + /** + * Generates an instance of the feature class named by type. + */ + private Feature createFeature(String name, String type, NamedParams params, + int id) throws FeatureException { + try { + Class c = Class.forName(type); + + Feature f = (Feature) c.newInstance(); + f.init(name, params, id); + return f; + + } catch (Exception e) { + throw new FeatureException(e.getMessage(), e); + } + } + + @SuppressWarnings("unchecked") + @Override + public Object applyUpdatesToManagedData(Object updates) { + if (updates instanceof List) { + List> up = (List>) updates; + for (Map u : up) { + update(u); + } + } + + if (updates instanceof Map) { + // a single feature + update((Map) updates); + } + + // logger.info("fstore updated, features: "); + // for (String s : store.getFeatureNames()) { + // logger.info(" - {}", s); + // + // } + List features = new ArrayList<>(); + for (FeatureStore fs : stores.values()) { + features.addAll(fs.featuresAsManagedResources()); + } + return features; + } + + @Override + public void doDeleteChild(BaseSolrResource endpoint, String childId) { + if (childId.equals("*")) { + stores.clear(); + return; + } + if (stores.containsKey(childId)) { + stores.remove(childId); + } + } + + /** + * Called to retrieve a named part (the given childId) of the 
resource at the + * given endpoint. Note: since we have a unique child feature store we ignore + * the childId. + */ + @Override + public void doGet(BaseSolrResource endpoint, String childId) { + SolrQueryResponse response = endpoint.getSolrResponse(); + + // If no feature store specified, show all the feature stores available + if (childId == null) { + response.add(FEATURE_STORE_JSON_FIELD, stores.keySet()); + } else { + FeatureStore store = getFeatureStore(childId); + if (store == null) { + throw new SolrException(ErrorCode.BAD_REQUEST, + "missing feature store [" + childId + "]"); + } + response.add(FEATURES_JSON_FIELD, store.featuresAsManagedResources()); + } + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/rest/ManagedModelStore.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/rest/ManagedModelStore.java new file mode 100644 index 000000000000..897f98155c70 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/rest/ManagedModelStore.java @@ -0,0 +1,304 @@ +package org.apache.solr.ltr.rest; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import java.io.IOException; +import java.lang.invoke.MethodHandles; +import java.lang.reflect.Constructor; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; + +import org.apache.solr.common.SolrException; +import org.apache.solr.common.SolrException.ErrorCode; +import org.apache.solr.common.util.NamedList; +import org.apache.solr.core.SolrResourceLoader; +import org.apache.solr.ltr.feature.FeatureStore; +import org.apache.solr.ltr.feature.ModelMetadata; +import org.apache.solr.ltr.feature.ModelStore; +import org.apache.solr.ltr.feature.norm.Normalizer; +import org.apache.solr.ltr.feature.norm.impl.IdentityNormalizer; +import org.apache.solr.ltr.ranking.Feature; +import org.apache.solr.ltr.util.FeatureException; +import org.apache.solr.ltr.util.ModelException; +import org.apache.solr.ltr.util.NamedParams; +import org.apache.solr.ltr.util.NormalizerException; +import org.apache.solr.response.SolrQueryResponse; +import org.apache.solr.rest.BaseSolrResource; +import org.apache.solr.rest.ManagedResource; +import org.apache.solr.rest.ManagedResourceStorage.StorageIO; +import org.apache.solr.rest.RestManager; +import org.noggit.ObjectBuilder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Managed resource for storing models. + */ +public class ManagedModelStore extends ManagedResource implements + ManagedResource.ChildResourceSupport { + + ModelStore store; + private ManagedFeatureStore featureStores; + + private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + + private static final String MODELS_JSON_FIELD = "models"; + + public ManagedModelStore(String resourceId, SolrResourceLoader loader, + StorageIO storageIO) throws SolrException { + super(resourceId, loader, storageIO); + + store = new ModelStore(); + + } + + public void init(ManagedFeatureStore featureStores) { + logger.info("INIT model store"); + this.featureStores = featureStores; 
+ } + + private Object managedData; + + @SuppressWarnings("unchecked") + @Override + protected void onManagedDataLoadedFromStorage(NamedList managedInitArgs, + Object managedData) throws SolrException { + store.clear(); + // the managed models on the disk or on zookeeper will be loaded in a lazy + // way, since we need to set the managed features first (unfortunately + // managed resources do not + // decouple the creation of a managed resource with the reading of the data + // from the storage) + this.managedData = managedData; + + } + + public void loadStoredModels() { + logger.info("------ managed models ~ loading ------"); + + if (managedData != null && managedData instanceof List) { + List<Map<String,Object>> up = (List<Map<String,Object>>) managedData; + for (Map u : up) { + try { + update(u); + } catch (ModelException e) { + throw new SolrException(ErrorCode.BAD_REQUEST, e); + } + } + } + } + + public static Normalizer getNormalizerInstance(String type, NamedParams params) + throws NormalizerException { + Normalizer f; + Class c; + try { + c = Class.forName(type); + + f = (Normalizer) c.newInstance(); + } catch (ClassNotFoundException | InstantiationException + | IllegalAccessException e) { + throw new NormalizerException("missing normalizer " + type, e); + } + f.setType(type); + if (params == null) { + params = NamedParams.EMPTY; + } + f.init(params); + return f; + + } + + @SuppressWarnings("unchecked") + private Feature parseFeature(Map featureMap, + FeatureStore featureStore) throws NormalizerException, FeatureException, + CloneNotSupportedException { + // FIXME name shouldn't be null, exception? + String name = (String) featureMap.get("name"); + + Normalizer norm = IdentityNormalizer.INSTANCE; + if (featureMap.containsKey("norm")) { + logger.info("adding normalizer {}", featureMap); + Map normMap = (Map) featureMap.get("norm"); + // FIXME type shouldn't be null, exception?
+ String type = ((String) normMap.get("type")); + NamedParams params = null; + if (normMap.containsKey("params")) { + Object paramsObj = normMap.get("params"); + if (paramsObj != null) { + params = new NamedParams((Map) paramsObj); + } + } + norm = getNormalizerInstance(type, params); + } + if (featureStores == null) { + throw new FeatureException("missing feature store"); + } + + Feature meta = featureStore.get(name); + meta = (Feature) meta.clone(); + meta.setNorm(norm); + + return meta; + } + + @SuppressWarnings("unchecked") + public ModelMetadata makeModelMetaData(String json) throws ModelException { + Object parsedJson = null; + try { + parsedJson = ObjectBuilder.fromJSON(json); + } catch (IOException ioExc) { + throw new ModelException("ObjectBuilder failed parsing json", ioExc); + } + return makeModelMetaData((Map) parsedJson); + } + + @SuppressWarnings("unchecked") + public ModelMetadata makeModelMetaData(Map map) + throws ModelException { + String name = (String) map.get("name"); + Object o = map.get("store"); + String featureStoreName = (o == null) ? 
ManagedFeatureStore.DEFAULT_FSTORE + : (String) o; + NamedParams params = null; + FeatureStore fstore = featureStores.getFeatureStore(featureStoreName); + if (!map.containsKey("features")) { + throw new SolrException(ErrorCode.BAD_REQUEST, + "Missing mandatory field features"); + } + List featureList = (List) map.get("features"); + + List features = new ArrayList<>(); + + for (Object f : featureList) { + try { + Feature feature = parseFeature((Map) f, fstore); + if (!fstore.containsFeature(feature.getName())) { + throw new ModelException("missing feature " + feature.getName() + + " in model " + name); + } + features.add(feature); + } catch (NormalizerException | FeatureException e) { + throw new SolrException(ErrorCode.BAD_REQUEST, e); + } catch (CloneNotSupportedException e) { + throw new SolrException(ErrorCode.BAD_REQUEST, e); + } + } + + if (map.containsKey("params")) { + Map paramsMap = (Map) map.get("params"); + params = new NamedParams(paramsMap); + } + + String type = (String) map.get("type"); + ModelMetadata meta = null; + try { + Class cl = Class.forName(type); + Constructor cons = cl.getDeclaredConstructor(String.class, + String.class, List.class, String.class, Collection.class, + NamedParams.class); + meta = (ModelMetadata) cons.newInstance(name, type, features, + featureStoreName, fstore.getFeatures(), params); + } catch (Exception e) { + throw new ModelException("Model type does not exist " + type, e); + } + + return meta; + } + + @SuppressWarnings("unchecked") + private void update(Map map) throws ModelException { + + ModelMetadata meta = makeModelMetaData(map); + try { + addMetadataModel(meta); + } catch (ModelException e) { + throw new SolrException(ErrorCode.BAD_REQUEST, e); + } + } + + @SuppressWarnings("unchecked") + @Override + protected Object applyUpdatesToManagedData(Object updates) { + if (updates instanceof List) { + List<Map<String,Object>> up = (List<Map<String,Object>>) updates; + for (Map u : up) { + try { + update(u); + } catch (ModelException e) { + throw new
SolrException(ErrorCode.BAD_REQUEST, e); + } + } + } + + if (updates instanceof Map) { + Map map = (Map) updates; + try { + update(map); + } catch (ModelException e) { + throw new SolrException(ErrorCode.BAD_REQUEST, e); + } + } + + return store.modelAsManagedResources(); + } + + @Override + public void doDeleteChild(BaseSolrResource endpoint, String childId) { + if (childId.equals("*")) store.clear(); + if (store.containsModel(childId)) store.delete(childId); + } + + /** + * Called to retrieve a named part (the given childId) of the resource at the + * given endpoint. Note: since we have a unique child managed store we ignore + * the childId. + */ + @Override + public void doGet(BaseSolrResource endpoint, String childId) { + + SolrQueryResponse response = endpoint.getSolrResponse(); + response.add(MODELS_JSON_FIELD, store.modelAsManagedResources()); + + } + + public synchronized void addMetadataModel(ModelMetadata modeldata) + throws ModelException { + logger.info("adding model {}", modeldata.getName()); + store.addModel(modeldata); + } + + public ModelMetadata getModel(String modelName) throws ModelException { + // this function replicates getModelStore().getModel(modelName), but + // it simplifies testing (we avoid having to mock a ModelStore as well). + return store.getModel(modelName); + } + + public ModelStore getModelStore() { + return store; + } + + @Override + public String toString() { + return "ManagedModelStore [store=" + store + ", featureStores=" + + featureStores + "]"; + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/rest/package-info.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/rest/package-info.java new file mode 100644 index 000000000000..f61019e5f59d --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/rest/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Contains the {@link org.apache.solr.rest.ManagedResource} implementations + * that encapsulate the feature and the model stores. + */ +package org.apache.solr.ltr.rest; diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/util/FeatureException.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/util/FeatureException.java new file mode 100644 index 000000000000..7bf57885572b --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/util/FeatureException.java @@ -0,0 +1,32 @@ +package org.apache.solr.ltr.util; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +public class FeatureException extends LtrException { + + private static final long serialVersionUID = 1L; + + public FeatureException(String msg) { + super(msg); + } + + public FeatureException(String msg, Exception parent) { + super(msg, parent); + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/util/InvalidFeatureNameException.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/util/InvalidFeatureNameException.java new file mode 100644 index 000000000000..854428907ea7 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/util/InvalidFeatureNameException.java @@ -0,0 +1,32 @@ +package org.apache.solr.ltr.util; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +public class InvalidFeatureNameException extends LtrException { + + private static final long serialVersionUID = 1L; + + public InvalidFeatureNameException(String featureName) { + super("Invalid feature name " + featureName); + } + + public InvalidFeatureNameException(String featureName, Exception parent) { + super("Invalid feature name " + featureName, parent); + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/util/LtrException.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/util/LtrException.java new file mode 100644 index 000000000000..0d6141518fa9 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/util/LtrException.java @@ -0,0 +1,34 @@ +package org.apache.solr.ltr.util; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import java.io.IOException; + +public class LtrException extends IOException { + + private static final long serialVersionUID = 1L; + + public LtrException(String message) { + super(message); + } + + public LtrException(String message, Exception parent) { + super(message, parent); + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/util/MacroExpander.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/util/MacroExpander.java new file mode 100644 index 000000000000..f88d4fc8e72e --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/util/MacroExpander.java @@ -0,0 +1,104 @@ +package org.apache.solr.ltr.util; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import java.util.Map; + +// TODO This should be replaced with the MacroExpander inside Solr 5.2 +public class MacroExpander { + public static final String MACRO_START = "${"; + + private Map orig; + private String macroStart = MACRO_START; + private char escape = '\\'; + + public MacroExpander(Map orig) { + this.orig = orig; + } + + public static String expand(String val, Map params) { + MacroExpander mc = new MacroExpander(params); + return mc.expand(val); + } + + public String expand(String val) { + // quickest short circuit + int idx = val.indexOf(macroStart.charAt(0)); + if (idx < 0) return val; + + int start = 0; // start of the unprocessed part of the string + int end = 0; + StringBuilder sb = null; + for (;;) { + idx = val.indexOf(macroStart, idx); + int matchedStart = idx; + + // check if escaped + if (idx > 0) { + // check if escaped... + // TODO: what if you *want* to actually have a backslash... perhaps + // that's when we allow changing + // of the escape character? + + char ch = val.charAt(idx - 1); + if (ch == escape) { + idx += macroStart.length(); + continue; + } + } else if (idx < 0) { + if (sb == null) return val; + sb.append(val.substring(start)); + return sb.toString(); + } + + // found unescaped "${" + idx += macroStart.length(); + + int rbrace = val.indexOf('}', idx); + if (rbrace == -1) { + // no matching close brace... + continue; + } + + if (sb == null) { + sb = new StringBuilder(val.length() * 2); + } + + if (matchedStart > 0) { + sb.append(val.substring(start, matchedStart)); + } + + // update "start" to be at the end of ${...} + start = rbrace + 1; + + String paramName = val.substring(idx, rbrace); + + // in the event that expansions become context dependent... consult + // original? + String replacement = orig.get(paramName); + + // TODO - handle a list somehow... 
+ if (replacement != null) { + sb.append(replacement); + } else { + sb.append(val.substring(matchedStart, start)); + } + + } + } +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/util/ModelException.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/util/ModelException.java new file mode 100644 index 000000000000..47c5a6754f4d --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/util/ModelException.java @@ -0,0 +1,32 @@ +package org.apache.solr.ltr.util; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +public class ModelException extends LtrException { + + private static final long serialVersionUID = 1L; + + public ModelException(String message) { + super(message); + } + + public ModelException(String message, Exception parent) { + super(message, parent); + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/util/NameValidator.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/util/NameValidator.java new file mode 100644 index 000000000000..b97d5be1594d --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/util/NameValidator.java @@ -0,0 +1,35 @@ +package org.apache.solr.ltr.util; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +public class NameValidator { + private static Pattern pattern = Pattern + .compile("^[a-zA-Z0-9][a-zA-Z0-9_.\\-/\\(/\\)]*$"); + + public static boolean check(String name) { + if (name == null) { + return false; + } + Matcher matcher = pattern.matcher(name); + return matcher.find(); + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/util/NamedParams.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/util/NamedParams.java new file mode 100644 index 000000000000..0446805c9946 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/util/NamedParams.java @@ -0,0 +1,97 @@ +package org.apache.solr.ltr.util; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class NamedParams extends HashMap { + + private static final long serialVersionUID = 1L; + public static final NamedParams EMPTY = new NamedParams(); + + public static float convertToFloat(Object o) { + float f = 0; + if (o instanceof Double) { + double d = (Double) o; + f = (float) d; + return f; + } + if (o instanceof Integer) { + int d = (Integer) o; + f = (float) d; + return f; + } + if (o instanceof Long) { + long l = (Long) o; + f = (float) l; + return f; + } + if (o instanceof Float) { + Float ff = (Float) o; + f = (float) ff; + return f; + } + + throw new NumberFormatException(o.getClass().getName() + + " cannot be converted to float"); + } + + public NamedParams() {} + + public NamedParams(Map params) { + for (Map.Entry p : params.entrySet()) { + add(p.getKey(), p.getValue()); + } + } + + @SuppressWarnings("unchecked") + public NamedParams(Object o) { + this((Map) o); + } + + public NamedParams add(String name, Object value) { + put(name, value); + return this; + } + + public double getDouble(String key, double defValue) { + if (containsKey(key)) { + return (double) get(key); + } + return defValue; + } + + public List getList(String key) { + if (containsKey(key)) { + return (List) get(key); + } + return null; + } + + public float getFloat(String key) { + Object o = get(key); + return convertToFloat(o); + } + + public float getFloat(String key, float value) { + return (containsKey(key)) ? 
getFloat(key) : value; + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/util/NormalizerException.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/util/NormalizerException.java new file mode 100644 index 000000000000..ee34ba794e3d --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/util/NormalizerException.java @@ -0,0 +1,32 @@ +package org.apache.solr.ltr.util; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +public class NormalizerException extends LtrException { + + private static final long serialVersionUID = 1L; + + public NormalizerException(String msg) { + super(msg); + } + + public NormalizerException(String message, Exception parent) { + super(message, parent); + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/util/package-info.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/util/package-info.java new file mode 100644 index 000000000000..55d1f4405c01 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/util/package-info.java @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Some utilities + */ +package org.apache.solr.ltr.util; diff --git a/solr/contrib/ltr/src/java/overview.html b/solr/contrib/ltr/src/java/overview.html new file mode 100644 index 000000000000..a04b97798e4c --- /dev/null +++ b/solr/contrib/ltr/src/java/overview.html @@ -0,0 +1,21 @@ + + + +Apache Solr Search Server: Learning to Rank Contrib + + diff --git a/solr/contrib/ltr/src/test-files/featureExamples/external_features.json b/solr/contrib/ltr/src/test-files/featureExamples/external_features.json new file mode 100644 index 000000000000..71ca99cac084 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/featureExamples/external_features.json @@ -0,0 +1,13 @@ +[ { + "name" : "matchedTitle", + "type" : "org.apache.solr.ltr.feature.impl.SolrFeature", + "params" : { + "q" : "{!terms f=title}${user_query}" + } +}, { + "name" : "titlePhraseMatch", + "type" : "org.apache.solr.ltr.feature.impl.SolrFeature", + "params" : { + "q" : "{!field f=title}${user_query}" + } +} ] \ No newline at end of file diff --git a/solr/contrib/ltr/src/test-files/featureExamples/features-file-simple-params.json b/solr/contrib/ltr/src/test-files/featureExamples/features-file-simple-params.json new file mode 100644 index 000000000000..023f25e452f1 --- /dev/null +++ 
b/solr/contrib/ltr/src/test-files/featureExamples/features-file-simple-params.json @@ -0,0 +1,12 @@ +[ + { + "name": "constant", + "type": "org.apache.solr.ltr.feature.impl.ValueFeature", + "params": { + "value": 100, + "complex":{ + "map":0.1 + } + } + } +] \ No newline at end of file diff --git a/solr/contrib/ltr/src/test-files/featureExamples/features-file-simple.json b/solr/contrib/ltr/src/test-files/featureExamples/features-file-simple.json new file mode 100644 index 000000000000..601d3dd5f3b0 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/featureExamples/features-file-simple.json @@ -0,0 +1,6 @@ +[ +{ + "name":"constant"; + "type":"org.apache.solr.ltr.feature.impl.ValueFeature" +} +] \ No newline at end of file diff --git a/solr/contrib/ltr/src/test-files/featureExamples/features-ranksvm.json b/solr/contrib/ltr/src/test-files/featureExamples/features-ranksvm.json new file mode 100644 index 000000000000..b3405fc6f091 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/featureExamples/features-ranksvm.json @@ -0,0 +1,51 @@ +[ + { + "name": "title", + "type": "org.apache.solr.ltr.feature.impl.ValueFeature", + "params": { + "value": 1 + } + }, + { + "name": "description", + "type": "org.apache.solr.ltr.feature.impl.ValueFeature", + "params": { + "value": 2 + } + }, + { + "name": "keywords", + "type": "org.apache.solr.ltr.feature.impl.ValueFeature", + "params": { + "value": 2 + } + }, + { + "name": "popularity", + "type": "org.apache.solr.ltr.feature.impl.ValueFeature", + "params": { + "value": 3 + } + }, + { + "name": "text", + "type": "org.apache.solr.ltr.feature.impl.ValueFeature", + "params": { + "value": 4 + } + }, + { + "name": "queryIntentPerson", + "type": "org.apache.solr.ltr.feature.impl.ValueFeature", + "params": { + "value": 5 + } + }, + { + "name": "queryIntentCompany", + "type": "org.apache.solr.ltr.feature.impl.ValueFeature", + "params": { + "value": 5 + } + } +] \ No newline at end of file diff --git 
a/solr/contrib/ltr/src/test-files/featureExamples/features-store-test-model.json b/solr/contrib/ltr/src/test-files/featureExamples/features-store-test-model.json new file mode 100644 index 000000000000..4161c68d997a --- /dev/null +++ b/solr/contrib/ltr/src/test-files/featureExamples/features-store-test-model.json @@ -0,0 +1,51 @@ +[ + { + "name": "constant1", + "type": "org.apache.solr.ltr.feature.impl.ValueFeature", + "store":"test", + "params": { + "value": 1 + } + }, + { + "name": "constant2", + "type": "org.apache.solr.ltr.feature.impl.ValueFeature", + "store":"test", + "params": { + "value": 2 + } + }, + { + "name": "constant3", + "type": "org.apache.solr.ltr.feature.impl.ValueFeature", + "store":"test", + "params": { + "value": 3 + } + }, + { + "name": "constant4", + "type": "org.apache.solr.ltr.feature.impl.ValueFeature", + "store":"test", + "params": { + "value": 4 + } + }, + { + "name": "constant5", + "type": "org.apache.solr.ltr.feature.impl.ValueFeature", + "store":"test", + "params": { + "value": 5 + } + }, + { + "name": "pop", + "type": "org.apache.solr.ltr.feature.impl.FieldValueFeature", + "store":"test", + "params": { + "field": "popularity" + } + } + +] \ No newline at end of file diff --git a/solr/contrib/ltr/src/test-files/featureExamples/features.json b/solr/contrib/ltr/src/test-files/featureExamples/features.json new file mode 100644 index 000000000000..624f6100d599 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/featureExamples/features.json @@ -0,0 +1,6 @@ +[ + +{ "name": "comp_industryTermScorer", "type": "org.apache.solr.ltr.feature.impl.SolrFeature", "params": {"q": "${user_query}","df": "comp_industry"}}, +{ "name": "comp_strongKeywordsTermScorer", "type": "org.apache.solr.ltr.feature.impl.SolrFeature", "params": {"q": "${user_query}","df": "comp_strongKeywords"}} + +] \ No newline at end of file diff --git a/solr/contrib/ltr/src/test-files/featureExamples/lambdamart_features.json 
b/solr/contrib/ltr/src/test-files/featureExamples/lambdamart_features.json new file mode 100644 index 000000000000..6f6400a6aa2c --- /dev/null +++ b/solr/contrib/ltr/src/test-files/featureExamples/lambdamart_features.json @@ -0,0 +1,16 @@ +[ + { + "name": "matchedTitle", + "type": "org.apache.solr.ltr.feature.impl.SolrFeature", + "params": { + "q": "{!terms f=title}${user_query}" + } + }, + { + "name": "constantScoreToForceLambdaMARTScoreAllDocs", + "type": "org.apache.solr.ltr.feature.impl.ValueFeature", + "params": { + "value": 1 + } + } +] diff --git a/solr/contrib/ltr/src/test-files/log4j.properties b/solr/contrib/ltr/src/test-files/log4j.properties new file mode 100644 index 000000000000..d86c6988d5ed --- /dev/null +++ b/solr/contrib/ltr/src/test-files/log4j.properties @@ -0,0 +1,32 @@ +# Logging level +log4j.rootLogger=INFO, CONSOLE + +log4j.appender.CONSOLE=org.apache.log4j.ConsoleAppender +log4j.appender.CONSOLE.Target=System.err +log4j.appender.CONSOLE.layout=org.apache.log4j.EnhancedPatternLayout +log4j.appender.CONSOLE.layout.ConversionPattern=%-4r %-5p (%t) [%X{node_name} %X{collection} %X{shard} %X{replica} %X{core}] %c{1.} %m%n +log4j.logger.org.apache.zookeeper=WARN +log4j.logger.org.apache.hadoop=WARN +log4j.logger.org.apache.directory=WARN +log4j.logger.org.apache.solr.hadoop=INFO +log4j.logger.org.apache.solr.client.solrj.embedded.JettySolrRunner=DEBUG +org.apache.solr.client.solrj.embedded.JettySolrRunner=DEBUG + +#log4j.logger.org.apache.solr.update.processor.LogUpdateProcessor=DEBUG +#log4j.logger.org.apache.solr.update.processor.DistributedUpdateProcessor=DEBUG +#log4j.logger.org.apache.solr.update.PeerSync=DEBUG +#log4j.logger.org.apache.solr.core.CoreContainer=DEBUG +#log4j.logger.org.apache.solr.cloud.RecoveryStrategy=DEBUG +#log4j.logger.org.apache.solr.cloud.SyncStrategy=DEBUG +#log4j.logger.org.apache.solr.handler.admin.CoreAdminHandler=DEBUG +#log4j.logger.org.apache.solr.cloud.ZkController=DEBUG 
+#log4j.logger.org.apache.solr.update.DefaultSolrCoreState=DEBUG +#log4j.logger.org.apache.solr.common.cloud.ConnectionManager=DEBUG +#log4j.logger.org.apache.solr.update.UpdateLog=DEBUG +#log4j.logger.org.apache.solr.cloud.ChaosMonkey=DEBUG +#log4j.logger.org.apache.solr.update.TransactionLog=DEBUG +#log4j.logger.org.apache.solr.handler.ReplicationHandler=DEBUG +#log4j.logger.org.apache.solr.handler.IndexFetcher=DEBUG + +#log4j.logger.org.apache.solr.common.cloud.ClusterStateUtil=DEBUG +#log4j.logger.org.apache.solr.cloud.OverseerAutoReplicaFailoverThread=DEBUG diff --git a/solr/contrib/ltr/src/test-files/modelExamples/external_model.json b/solr/contrib/ltr/src/test-files/modelExamples/external_model.json new file mode 100644 index 000000000000..ca3a18666280 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/modelExamples/external_model.json @@ -0,0 +1,12 @@ +{ + "type":"org.apache.solr.ltr.ranking.RankSVMModel", + "name":"externalmodel", + "features":[ + { "name": "matchedTitle"} + ], + "params":{ + "weights": { + "matchedTitle": 0.999 + } + } +} diff --git a/solr/contrib/ltr/src/test-files/modelExamples/lambdamart_model.json b/solr/contrib/ltr/src/test-files/modelExamples/lambdamart_model.json new file mode 100644 index 000000000000..91b1e047ea14 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/modelExamples/lambdamart_model.json @@ -0,0 +1,38 @@ +{ + "type":"org.apache.solr.ltr.ranking.LambdaMARTModel", + "name":"lambdamartmodel", + "features":[ + { "name": "matchedTitle"}, + { "name": "constantScoreToForceLambdaMARTScoreAllDocs"} + ], + "params":{ + "trees": [ + { + "weight" : 1, + "tree": { + "feature": "matchedTitle", + "threshold": 0.5, + "left" : { + "value" : -100 + }, + "right": { + "feature" : "this_feature_doesnt_exist", + "threshold": 10.0, + "left" : { + "value" : 50 + }, + "right" : { + "value" : 75 + } + } + } + }, + { + "weight" : 2, + "tree": { + "value" : -10 + } + } + ] + } +} diff --git 
a/solr/contrib/ltr/src/test-files/modelExamples/lambdamart_model_no_feature.json b/solr/contrib/ltr/src/test-files/modelExamples/lambdamart_model_no_feature.json new file mode 100644 index 000000000000..24499890f58c --- /dev/null +++ b/solr/contrib/ltr/src/test-files/modelExamples/lambdamart_model_no_feature.json @@ -0,0 +1,24 @@ +{ + "type":"org.apache.solr.ltr.ranking.LambdaMARTModel", + "name":"lambdamartmodel_no_feature", + "features":[ + { "name": "matchedTitle"}, + { "name": "constantScoreToForceLambdaMARTScoreAllDocs"} + ], + "params":{ + "trees": [ + { + "weight" : 1, + "tree": { + "threshold": 0.5, + "left" : { + "value" : -100 + }, + "right": { + "value" : 75 + } + } + } + ] + } +} diff --git a/solr/contrib/ltr/src/test-files/modelExamples/lambdamart_model_no_features.json b/solr/contrib/ltr/src/test-files/modelExamples/lambdamart_model_no_features.json new file mode 100644 index 000000000000..e0164c4bc7f6 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/modelExamples/lambdamart_model_no_features.json @@ -0,0 +1,14 @@ +{ + "type":"org.apache.solr.ltr.ranking.LambdaMARTModel", + "name":"lambdamartmodel_no_features", + "params":{ + "trees": [ + { + "weight" : 2, + "tree": { + "value" : -10 + } + } + ] + } +} diff --git a/solr/contrib/ltr/src/test-files/modelExamples/lambdamart_model_no_left.json b/solr/contrib/ltr/src/test-files/modelExamples/lambdamart_model_no_left.json new file mode 100644 index 000000000000..3c3f89588b0b --- /dev/null +++ b/solr/contrib/ltr/src/test-files/modelExamples/lambdamart_model_no_left.json @@ -0,0 +1,22 @@ +{ + "type":"org.apache.solr.ltr.ranking.LambdaMARTModel", + "name":"lambdamartmodel_no_left", + "features":[ + { "name": "matchedTitle"}, + { "name": "constantScoreToForceLambdaMARTScoreAllDocs"} + ], + "params":{ + "trees": [ + { + "weight" : 1, + "tree": { + "feature": "matchedTitle", + "threshold": 0.5, + "right": { + "value" : 75 + } + } + } + ] + } +} diff --git 
a/solr/contrib/ltr/src/test-files/modelExamples/lambdamart_model_no_params.json b/solr/contrib/ltr/src/test-files/modelExamples/lambdamart_model_no_params.json new file mode 100644 index 000000000000..b11e0a8652f5 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/modelExamples/lambdamart_model_no_params.json @@ -0,0 +1,8 @@ +{ + "type":"org.apache.solr.ltr.ranking.LambdaMARTModel", + "name":"lambdamartmodel_no_params", + "features":[ + { "name": "matchedTitle"}, + { "name": "constantScoreToForceLambdaMARTScoreAllDocs"} + ] +} diff --git a/solr/contrib/ltr/src/test-files/modelExamples/lambdamart_model_no_right.json b/solr/contrib/ltr/src/test-files/modelExamples/lambdamart_model_no_right.json new file mode 100644 index 000000000000..70f78ce8f6ac --- /dev/null +++ b/solr/contrib/ltr/src/test-files/modelExamples/lambdamart_model_no_right.json @@ -0,0 +1,22 @@ +{ + "type":"org.apache.solr.ltr.ranking.LambdaMARTModel", + "name":"lambdamartmodel_no_right", + "features":[ + { "name": "matchedTitle"}, + { "name": "constantScoreToForceLambdaMARTScoreAllDocs"} + ], + "params":{ + "trees": [ + { + "weight" : 1, + "tree": { + "feature": "matchedTitle", + "threshold": 0.5, + "left" : { + "value" : -100 + } + } + } + ] + } +} diff --git a/solr/contrib/ltr/src/test-files/modelExamples/lambdamart_model_no_threshold.json b/solr/contrib/ltr/src/test-files/modelExamples/lambdamart_model_no_threshold.json new file mode 100644 index 000000000000..3982e06f5199 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/modelExamples/lambdamart_model_no_threshold.json @@ -0,0 +1,24 @@ +{ + "type":"org.apache.solr.ltr.ranking.LambdaMARTModel", + "name":"lambdamartmodel_no_threshold", + "features":[ + { "name": "matchedTitle"}, + { "name": "constantScoreToForceLambdaMARTScoreAllDocs"} + ], + "params":{ + "trees": [ + { + "weight" : 1, + "tree": { + "feature": "matchedTitle", + "left" : { + "value" : -100 + }, + "right": { + "value" : 75 + } + } + } + ] + } +} diff --git 
a/solr/contrib/ltr/src/test-files/modelExamples/lambdamart_model_no_tree.json b/solr/contrib/ltr/src/test-files/modelExamples/lambdamart_model_no_tree.json new file mode 100644 index 000000000000..148e2f056eec --- /dev/null +++ b/solr/contrib/ltr/src/test-files/modelExamples/lambdamart_model_no_tree.json @@ -0,0 +1,15 @@ +{ + "type":"org.apache.solr.ltr.ranking.LambdaMARTModel", + "name":"lambdamartmodel_no_tree", + "features":[ + { "name": "matchedTitle"}, + { "name": "constantScoreToForceLambdaMARTScoreAllDocs"} + ], + "params":{ + "trees": [ + { + "weight" : 2 + } + ] + } +} diff --git a/solr/contrib/ltr/src/test-files/modelExamples/lambdamart_model_no_trees.json b/solr/contrib/ltr/src/test-files/modelExamples/lambdamart_model_no_trees.json new file mode 100644 index 000000000000..72d744965b58 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/modelExamples/lambdamart_model_no_trees.json @@ -0,0 +1,10 @@ +{ + "type":"org.apache.solr.ltr.ranking.LambdaMARTModel", + "name":"lambdamartmodel_no_trees", + "features":[ + { "name": "matchedTitle"}, + { "name": "constantScoreToForceLambdaMARTScoreAllDocs"} + ], + "params":{ + } +} diff --git a/solr/contrib/ltr/src/test-files/modelExamples/lambdamart_model_no_weight.json b/solr/contrib/ltr/src/test-files/modelExamples/lambdamart_model_no_weight.json new file mode 100644 index 000000000000..bd5ffefb54f4 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/modelExamples/lambdamart_model_no_weight.json @@ -0,0 +1,24 @@ +{ + "type":"org.apache.solr.ltr.ranking.LambdaMARTModel", + "name":"lambdamartmodel_no_weight", + "features":[ + { "name": "matchedTitle"}, + { "name": "constantScoreToForceLambdaMARTScoreAllDocs"} + ], + "params":{ + "trees": [ + { + "tree": { + "feature": "matchedTitle", + "threshold": 0.5, + "left" : { + "value" : -100 + }, + "right": { + "value" : 75 + } + } + } + ] + } +} diff --git a/solr/contrib/ltr/src/test-files/modelExamples/ranksvm-model.json 
b/solr/contrib/ltr/src/test-files/modelExamples/ranksvm-model.json new file mode 100644 index 000000000000..9e3c815a4ca5 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/modelExamples/ranksvm-model.json @@ -0,0 +1,30 @@ +{ + "type":"org.apache.solr.ltr.ranking.RankSVMModel", + "name":"6029760550880411648", + "features":[ + {"name":"title"}, + {"name":"description"}, + {"name":"keywords"}, + { + "name":"popularity", + "norm": { + "type":"org.apache.solr.ltr.feature.norm.impl.MinMaxNormalizer", + "params":{ "min":0.0, "max":10.0 } + } + }, + {"name":"text"}, + {"name":"queryIntentPerson"}, + {"name":"queryIntentCompany"} + ], + "params":{ + "weights": { + "title": 0.0000000000, + "description": 0.1000000000, + "keywords": 0.2000000000, + "popularity": 0.3000000000, + "text": 0.4000000000, + "queryIntentPerson":0.1231231, + "queryIntentCompany":0.12121211 + } + } +} \ No newline at end of file diff --git a/solr/contrib/ltr/src/test-files/modelExamples/svm-model-normalized.json b/solr/contrib/ltr/src/test-files/modelExamples/svm-model-normalized.json new file mode 100644 index 000000000000..1aec25098531 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/modelExamples/svm-model-normalized.json @@ -0,0 +1,22 @@ +{ + "type":"org.apache.solr.ltr.ranking.RankSVMModel", + "name":"norm2", + "features":[ + { + "name":"feature2normalize", + "norm": { + "type":"org.apache.solr.ltr.feature.norm.impl.StandardNormalizer", + + "params":{ + "avg":0.0, + "std":2.0 + } + } + } + ], + "params":{ + "weights":{ + "feature2normalize":1.0 + } + } +} diff --git a/solr/contrib/ltr/src/test-files/modelExamples/svm-model.json b/solr/contrib/ltr/src/test-files/modelExamples/svm-model.json new file mode 100644 index 000000000000..f22f87ff0962 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/modelExamples/svm-model.json @@ -0,0 +1,20 @@ +{ + "type":"org.apache.solr.ltr.ranking.RankSVMModel", + "name":"svm", + "features":[ + {"name":"constant1"}, + {"name":"constant2"}, + 
{"name":"constant3"}, + {"name":"constant4"}, + {"name":"constant5"} + ], + "params":{ + "weights":{ + "constant1":1, + "constant2":2, + "constant3":3, + "constant4":4, + "constant5":5 + } + } +} \ No newline at end of file diff --git a/solr/contrib/ltr/src/test-files/modelExamples/svm-model1.json b/solr/contrib/ltr/src/test-files/modelExamples/svm-model1.json new file mode 100644 index 000000000000..b24d32e8cb54 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/modelExamples/svm-model1.json @@ -0,0 +1,14 @@ +{ + "type":"org.apache.solr.ltr.ranking.RankSVMModel, + "name":"svm1", + "features":[ + {"name":"constant2"}, + {"name":"constant4"} + ], + "params":{ + "weights":{ + "constant2":3, + "constant4":6, + } + } +} \ No newline at end of file diff --git a/solr/contrib/ltr/src/test-files/modelExamples/svm-sum-model.json b/solr/contrib/ltr/src/test-files/modelExamples/svm-sum-model.json new file mode 100644 index 000000000000..733e73886cb4 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/modelExamples/svm-sum-model.json @@ -0,0 +1,20 @@ +{ + "type":"org.apache.solr.ltr.ranking.RankSVMModel", + "name":"sum", + "features":[ + {"name":"constant1"}, + {"name":"constant2"}, + {"name":"constant3"}, + {"name":"constant4"}, + {"name":"constant5"} + ], + "params":{ + "weights":{ + "constant1":1, + "constant2":1, + "constant3":1, + "constant4":1, + "constant5":1 + } + } +} diff --git a/solr/contrib/ltr/src/test-files/solr/collection1/conf/indexSynonyms.txt b/solr/contrib/ltr/src/test-files/solr/collection1/conf/indexSynonyms.txt new file mode 100644 index 000000000000..af55e6efd779 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/solr/collection1/conf/indexSynonyms.txt @@ -0,0 +1,18 @@ +# the asf licenses this file to you under the apache license, version 2.0 +# (the "license"); you may not use this file except in compliance with +# the license. 
you may obtain a copy of the license at +# +# http://www.apache.org/licenses/license-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the license is distributed on an "as is" basis, +# without warranties or conditions of any kind, either express or implied. +# see the license for the specific language governing permissions and +# limitations under the license. + +#----------------------------------------------------------------------- + +# some synonym groups specific to this example +gb,gib,gigabyte,gigabytes +mb,mib,megabyte,megabytes +television, televisions, tv, tvs diff --git a/solr/contrib/ltr/src/test-files/solr/collection1/conf/protwords.txt b/solr/contrib/ltr/src/test-files/solr/collection1/conf/protwords.txt new file mode 100644 index 000000000000..02cb4ac23bdb --- /dev/null +++ b/solr/contrib/ltr/src/test-files/solr/collection1/conf/protwords.txt @@ -0,0 +1,20 @@ +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#----------------------------------------------------------------------- +# Use a protected word file to protect against the stemmer reducing two +# unrelated words to the same base word. + +# Some non-words that normally won't be encountered, +# just to test that they won't be stemmed. 
+ +offical diff --git a/solr/contrib/ltr/src/test-files/solr/collection1/conf/schema-ltr.xml b/solr/contrib/ltr/src/test-files/solr/collection1/conf/schema-ltr.xml new file mode 100644 index 000000000000..949250899e41 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/solr/collection1/conf/schema-ltr.xml @@ -0,0 +1,87 @@ + + + + + + + + + + + + + + + + + + id + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/solr/contrib/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr.xml b/solr/contrib/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr.xml new file mode 100644 index 000000000000..628a84bbc08d --- /dev/null +++ b/solr/contrib/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr.xml @@ -0,0 +1,72 @@ + + + + + 6.0.0 + ${solr.data.dir:} + + + + + + + + + + + + + + + + + + + + + + 15000 + false + + + 1000 + + + ${solr.data.dir:} + + + + + + + + explicit + json + true + id + + + ltrComponent + + + + diff --git a/solr/contrib/ltr/src/test-files/solr/collection1/conf/stemdict.txt b/solr/contrib/ltr/src/test-files/solr/collection1/conf/stemdict.txt new file mode 100644 index 000000000000..78f05c223a85 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/solr/collection1/conf/stemdict.txt @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +#----------------------------------------------------------------------- +# test that we can override the stemming algorithm with our own mappings +# these must be tab-separated +salty salt diff --git a/solr/contrib/ltr/src/test-files/solr/collection1/conf/stopwords.txt b/solr/contrib/ltr/src/test-files/solr/collection1/conf/stopwords.txt new file mode 100644 index 000000000000..eabae3b7c0dd --- /dev/null +++ b/solr/contrib/ltr/src/test-files/solr/collection1/conf/stopwords.txt @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +a diff --git a/solr/contrib/ltr/src/test-files/solr/collection1/conf/synonyms.txt b/solr/contrib/ltr/src/test-files/solr/collection1/conf/synonyms.txt new file mode 100644 index 000000000000..0ef0e8daabaf --- /dev/null +++ b/solr/contrib/ltr/src/test-files/solr/collection1/conf/synonyms.txt @@ -0,0 +1,28 @@ +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#----------------------------------------------------------------------- +#some test synonym mappings unlikely to appear in real input text +aaafoo => aaabar +bbbfoo => bbbfoo bbbbar +cccfoo => cccbar cccbaz +fooaaa,baraaa,bazaaa + +# Some synonym groups specific to this example +GB,gib,gigabyte,gigabytes +MB,mib,megabyte,megabytes +Television, Televisions, TV, TVs +#notice we use "gib" instead of "GiB" so any WordDelimiterFilter coming +#after us won't split it into two words. + +# Synonym mappings can be used for spelling correction too +pixima => pixma diff --git a/solr/contrib/ltr/src/test-files/solr/collection1/conf/wdfftypes.txt b/solr/contrib/ltr/src/test-files/solr/collection1/conf/wdfftypes.txt new file mode 100644 index 000000000000..52e60f8643f2 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/solr/collection1/conf/wdfftypes.txt @@ -0,0 +1,11 @@ +# A customized type mapping for WordDelimiterFilterFactory +# the allowable types are: LOWER, UPPER, ALPHA, DIGIT, ALPHANUM, SUBWORD_DELIM +# +# the default for any character without a mapping is always computed from +# Unicode character properties + +# Map the $, % characters to DIGIT +# This might be useful for financial data. 
+$ => DIGIT +% => DIGIT +& => ALPHA diff --git a/solr/contrib/ltr/src/test-files/solr/solr.xml b/solr/contrib/ltr/src/test-files/solr/solr.xml new file mode 100644 index 000000000000..c8c3ebeb30a5 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/solr/solr.xml @@ -0,0 +1,42 @@ + + + + + + ${shareSchema:false} + ${configSetBaseDir:configsets} + ${coreRootDirectory:.} + + + ${urlScheme:} + ${socketTimeout:90000} + ${connTimeout:15000} + + + + 127.0.0.1 + ${hostPort:8983} + ${hostContext:solr} + ${solr.zkclienttimeout:30000} + ${genericCoreNodeNames:true} + ${leaderVoteWait:10000} + ${distribUpdateConnTimeout:45000} + ${distribUpdateSoTimeout:340000} + + + diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestRerankBase.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestRerankBase.java new file mode 100644 index 000000000000..c4b92070405b --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestRerankBase.java @@ -0,0 +1,422 @@ +package org.apache.solr.ltr; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.lang.invoke.MethodHandles; +import java.net.URL; +import java.nio.charset.Charset; +import java.nio.file.Files; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Scanner; +import java.util.SortedMap; +import java.util.TreeMap; + +import org.apache.commons.io.FileUtils; +import org.apache.commons.lang.StringUtils; +import org.apache.solr.SolrTestCaseJ4.SuppressSSL; +import org.apache.solr.common.params.CommonParams; +import org.apache.solr.common.util.ContentStream; +import org.apache.solr.common.util.ContentStreamBase; +import org.apache.solr.ltr.feature.ModelMetadata; +import org.apache.solr.ltr.feature.impl.ValueFeature; +import org.apache.solr.ltr.feature.impl.ValueFeature.ValueFeatureWeight; +import org.apache.solr.ltr.ranking.Feature; +import org.apache.solr.ltr.ranking.LTRComponent.LTRParams; +import org.apache.solr.ltr.ranking.RankSVMModel; +import org.apache.solr.ltr.rest.ManagedFeatureStore; +import org.apache.solr.ltr.rest.ManagedModelStore; +import org.apache.solr.ltr.util.FeatureException; +import org.apache.solr.ltr.util.ModelException; +import org.apache.solr.ltr.util.NamedParams; +import org.apache.solr.request.SolrQueryRequestBase; +import org.apache.solr.response.SolrQueryResponse; +import org.apache.solr.util.RestTestBase; +import org.eclipse.jetty.servlet.ServletHolder; +import org.noggit.ObjectBuilder; +import org.restlet.ext.servlet.ServerServlet; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@SuppressSSL +public class TestRerankBase extends RestTestBase { + + private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + + protected static File tmpSolrHome; + protected static File tmpConfDir; + + public static final String FEATURE_ENDPOINT = LTRParams.FSTORE_END_POINT; + public static final String MODEL_ENDPOINT = 
LTRParams.MSTORE_END_POINT; + public static final String FEATURE_FILE_NAME = "_schema_fstore.json"; + public static final String MODEL_FILE_NAME = "_schema_mstore.json"; + public static final String PARENT_ENDPOINT = "/schema/*"; + + protected static final String collection = "collection1"; + protected static final String confDir = collection + "/conf"; + + protected static File fstorefile = null; + protected static File mstorefile = null; + + public static void setuptest() throws Exception { + setuptest("solrconfig-ltr.xml", "schema-ltr.xml"); + bulkIndex(); + } + + public static void setupPersistenttest() throws Exception { + setupPersistentTest("solrconfig-ltr.xml", "schema-ltr.xml"); + bulkIndex(); + } + + // NOTE: this will return a new rest manager since the getCore() method + // returns a new instance of a restManager. + public static ManagedFeatureStore getNewManagedFeatureStore() { + ManagedFeatureStore fs = (ManagedFeatureStore) h.getCore().getRestManager() + .getManagedResource(FEATURE_ENDPOINT); + return fs; + } + + public static ManagedModelStore getNewManagedModelStore() { + + ManagedModelStore fs = (ManagedModelStore) h.getCore().getRestManager() + .getManagedResource(MODEL_ENDPOINT); + return fs; + } + + public static void setuptest(String solrconfig, String schema) + throws Exception { + initCore(solrconfig, schema); + + tmpSolrHome = createTempDir().toFile(); + tmpConfDir = new File(tmpSolrHome, confDir); + tmpConfDir.deleteOnExit(); + FileUtils.copyDirectory(new File(TEST_HOME()), + tmpSolrHome.getAbsoluteFile()); + File fstore = new File(tmpConfDir, FEATURE_FILE_NAME); + File mstore = new File(tmpConfDir, MODEL_FILE_NAME); + + if (fstore.exists()) { + logger.info("remove feature store config file in {}", + fstore.getAbsolutePath()); + Files.delete(fstore.toPath()); + } + if (mstore.exists()) { + logger.info("remove model store config file in {}", + mstore.getAbsolutePath()); + Files.delete(mstore.toPath()); + } + if 
(!solrconfig.equals("solrconfig.xml")) FileUtils.copyFile(new File( + tmpSolrHome.getAbsolutePath() + "/collection1/conf/" + solrconfig), + new File(tmpSolrHome.getAbsolutePath() + + "/collection1/conf/solrconfig.xml")); + if (!schema.equals("schema.xml")) FileUtils + .copyFile(new File(tmpSolrHome.getAbsolutePath() + "/collection1/conf/" + + schema), new File(tmpSolrHome.getAbsolutePath() + + "/collection1/conf/schema.xml")); + + final SortedMap<ServletHolder,String> extraServlets = new TreeMap<>(); + final ServletHolder solrRestApi = new ServletHolder("SolrSchemaRestApi", + ServerServlet.class); + solrRestApi.setInitParameter("org.restlet.application", + "org.apache.solr.rest.SolrSchemaRestApi"); + solrRestApi.setInitParameter("storageIO", + "org.apache.solr.rest.ManagedResourceStorage$InMemoryStorageIO"); + extraServlets.put(solrRestApi, PARENT_ENDPOINT); + + System.setProperty("managed.schema.mutable", "true"); + System.setProperty("enable.update.log", "false"); + + createJettyAndHarness(tmpSolrHome.getAbsolutePath(), solrconfig, schema, + "/solr", true, extraServlets); + } + + public static void setupPersistentTest(String solrconfig, String schema) + throws Exception { + initCore(solrconfig, schema); + + tmpSolrHome = createTempDir().toFile(); + tmpConfDir = new File(tmpSolrHome, confDir); + tmpConfDir.deleteOnExit(); + FileUtils.copyDirectory(new File(TEST_HOME()), + tmpSolrHome.getAbsoluteFile()); + fstorefile = new File(tmpConfDir, FEATURE_FILE_NAME); + mstorefile = new File(tmpConfDir, MODEL_FILE_NAME); + + if (fstorefile.exists()) { + logger.info("remove feature store config file in {}", + fstorefile.getAbsolutePath()); + Files.delete(fstorefile.toPath()); + } + if (mstorefile.exists()) { + logger.info("remove model store config file in {}", + mstorefile.getAbsolutePath()); + Files.delete(mstorefile.toPath()); + } + // clearModelStore(); + + final SortedMap<ServletHolder,String> extraServlets = new TreeMap<>(); + final ServletHolder solrRestApi = new ServletHolder("SolrSchemaRestApi", + 
ServerServlet.class); + solrRestApi.setInitParameter("org.restlet.application", + "org.apache.solr.rest.SolrSchemaRestApi"); + solrRestApi.setInitParameter("storageIO", + "org.apache.solr.rest.ManagedResourceStorage$JsonStorageIO"); + + extraServlets.put(solrRestApi, PARENT_ENDPOINT); // '/schema/*' matches + // '/schema', + // '/schema/', and + // '/schema/whatever...' + + System.setProperty("managed.schema.mutable", "true"); + // System.setProperty("enable.update.log", "false"); + + createJettyAndHarness(tmpSolrHome.getAbsolutePath(), solrconfig, schema, + "/solr", true, extraServlets); + } + + protected static void aftertest() throws Exception { + + jetty.stop(); + jetty = null; + FileUtils.deleteDirectory(tmpSolrHome); + System.clearProperty("managed.schema.mutable"); + // System.clearProperty("enable.update.log"); + + restTestHarness = null; + } + + public static void makeRestTestHarnessNull() { + restTestHarness = null; + } + + /** produces a model encoded in json **/ + public static String getModelInJson(String name, String type, + String[] features, String fstore, String params) { + StringBuilder sb = new StringBuilder(); + sb.append("{\n"); + sb.append("\"name\":").append('"').append(name).append('"').append(",\n"); + sb.append("\"store\":").append('"').append(fstore).append('"') + .append(",\n"); + sb.append("\"type\":").append('"').append(type).append('"').append(",\n"); + sb.append("\"features\":").append('['); + for (String feature : features) { + sb.append("\n\t{ "); + sb.append("\"name\":").append('"').append(feature).append('"') + .append("},"); + } + sb.deleteCharAt(sb.length() - 1); + sb.append("\n]\n"); + if (params != null) { + sb.append(",\n"); + sb.append("\"params\":").append(params); + } + sb.append("\n}\n"); + return sb.toString(); + } + + /** produces a feature encoded in json **/ + public static String getFeatureInJson(String name, String type, + String fstore, String params) { + StringBuilder sb = new StringBuilder(); + sb.append("{\n"); + 
sb.append("\"name\":").append('"').append(name).append('"').append(",\n"); + sb.append("\"store\":").append('"').append(fstore).append('"') + .append(",\n"); + sb.append("\"type\":").append('"').append(type).append('"'); + if (params != null) { + sb.append(",\n"); + sb.append("\"params\":").append(params); + } + sb.append("\n}\n"); + return sb.toString(); + } + + protected static void loadFeature(String name, String type, String params) + throws Exception { + String feature = getFeatureInJson(name, type, "test", params); + logger.info("loading feature \n{} ", feature); + assertJPut(FEATURE_ENDPOINT, feature, "/responseHeader/status==0"); + } + + protected static void loadFeature(String name, String type, String fstore, + String params) throws Exception { + String feature = getFeatureInJson(name, type, fstore, params); + logger.info("loading feature \n{} ", feature); + assertJPut(FEATURE_ENDPOINT, feature, "/responseHeader/status==0"); + } + + protected static void loadModel(String name, String type, String[] features, + String params) throws Exception { + loadModel(name, type, features, "test", params); + } + + protected static void loadModel(String name, String type, String[] features, + String fstore, String params) throws Exception { + String model = getModelInJson(name, type, features, fstore, params); + logger.info("loading model \n{} ", model); + assertJPut(MODEL_ENDPOINT, model, "/responseHeader/status==0"); + } + + public static void loadModels(String fileName) throws Exception { + URL url = TestRerankBase.class.getResource("/modelExamples/" + fileName); + String multipleModels = FileUtils.readFileToString(new File(url.toURI()), "UTF-8"); + + assertJPut(MODEL_ENDPOINT, multipleModels, "/responseHeader/status==0"); + } + + public static void createModelFromFiles(String modelFileName, + String featureFileName) throws ModelException, Exception { + URL url = TestRerankBase.class.getResource("/modelExamples/" + + modelFileName); + String modelJson = 
FileUtils.readFileToString(new File(url.toURI()), "UTF-8"); + ManagedModelStore ms = getNewManagedModelStore(); + + url = TestRerankBase.class.getResource("/featureExamples/" + + featureFileName); + String featureJson = FileUtils.readFileToString(new File(url.toURI()),"UTF-8"); + + Object parsedFeatureJson = null; + try { + parsedFeatureJson = ObjectBuilder.fromJSON(featureJson); + } catch (IOException ioExc) { + throw new ModelException("ObjectBuilder failed parsing json", ioExc); + } + + ManagedFeatureStore fs = getNewManagedFeatureStore(); + // fs.getFeatureStore(null).clear(); + fs.doDeleteChild(null, "*"); // is this safe?? + // based on my need to call this I don't think that + // "getNewManagedFeatureStore()" + // is actually returning a new feature store each time + fs.applyUpdatesToManagedData(parsedFeatureJson); + ms.init(fs); + + ModelMetadata meta = ms.makeModelMetaData(modelJson); + ms.addMetadataModel(meta); + } + + public static void loadFeatures(String fileName) throws Exception { + URL url = TestRerankBase.class.getResource("/featureExamples/" + fileName); + String multipleFeatures = FileUtils.readFileToString(new File(url.toURI()),"UTF-8"); + logger.info("send \n{}", multipleFeatures); + + assertJPut(FEATURE_ENDPOINT, multipleFeatures, "/responseHeader/status==0"); + } + + protected List<Feature> getFeatures(List<String> names) + throws FeatureException { + List<Feature> features = new ArrayList<>(); + int pos = 0; + for (String name : names) { + ValueFeature f = new ValueFeature(); + f.init(name, new NamedParams().add("value", 10), pos); + features.add(f); + ++pos; + } + return features; + } + + protected List<Feature> getFeatures(String[] names) throws FeatureException { + return getFeatures(Arrays.asList(names)); + } + + protected static void loadModelAndFeatures(String name, int allFeatureCount, + int modelFeatureCount) throws Exception { + String[] features = new String[modelFeatureCount]; + String[] weights = new String[modelFeatureCount]; + for (int i = 0; i < 
allFeatureCount; i++) { + String featureName = "c" + i; + if (i < modelFeatureCount) { + features[i] = featureName; + weights[i] = "\"" + featureName + "\":1.0"; + } + loadFeature(featureName, ValueFeatureWeight.class.getCanonicalName(), + "{\"value\":" + i + "}"); + } + + loadModel(name, RankSVMModel.class.getCanonicalName(), features, + "{\"weights\":{" + StringUtils.join(weights, ",") + "}}"); + } + + protected static void bulkIndex() throws Exception { + System.out.println("-----------index ---------------------"); + assertU(adoc("title", "bloomberg different bla", "description", + "bloomberg", "id", "6", "popularity", "1")); + assertU(adoc("title", "bloomberg bloomberg ", "description", "bloomberg", + "id", "7", "popularity", "2")); + assertU(adoc("title", "bloomberg bloomberg bloomberg", "description", + "bloomberg", "id", "8", "popularity", "3")); + assertU(adoc("title", "bloomberg bloomberg bloomberg bloomberg", + "description", "bloomberg", "id", "9", "popularity", "5")); + assertU(commit()); + } + + protected static void bulkIndex(String filePath) throws Exception { + SolrQueryRequestBase req = lrf.makeRequest(CommonParams.STREAM_CONTENTTYPE, + "application/xml"); + + List<ContentStream> streams = new ArrayList<ContentStream>(); + File file = new File(filePath); + streams.add(new ContentStreamBase.FileStream(file)); + req.setContentStreams(streams); + + try { + SolrQueryResponse res = new SolrQueryResponse(); + h.updater.handleRequest(req, res); + } catch (Throwable ex) { + // Ignore. 
Just log the exception and go to the next file + logger.error(ex.getMessage()); + ex.printStackTrace(); + } + assertU(commit()); + + } + + protected static void buildIndexUsingAdoc(String filepath) + throws FileNotFoundException { + Scanner scn = new Scanner(new File(filepath),"UTF-8"); + StringBuffer buff = new StringBuffer(); + scn.nextLine(); + scn.nextLine(); + scn.nextLine(); // Skip the first 3 lines then add everything else + ArrayList docsToAdd = new ArrayList(); + while (scn.hasNext()) { + String curLine = scn.nextLine(); + if (curLine.contains("")) { + buff.append(curLine + "\n"); + docsToAdd.add(buff.toString().replace("", "") + .replace("", "\n") + .replace("", "\n")); + if (!scn.hasNext()) break; + else curLine = scn.nextLine(); + buff = new StringBuffer(); + } + buff.append(curLine + "\n"); + } + for (String doc : docsToAdd) { + assertU(doc.trim()); + } + assertU(commit()); + scn.close(); + } + +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureLogging.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureLogging.java new file mode 100644 index 000000000000..337400da0cef --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureLogging.java @@ -0,0 +1,182 @@ +package org.apache.solr.ltr.feature; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.solr.SolrTestCaseJ4.SuppressSSL; +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.ltr.TestRerankBase; +import org.apache.solr.ltr.feature.impl.FieldValueFeature; +import org.apache.solr.ltr.feature.impl.SolrFeature; +import org.apache.solr.ltr.feature.impl.ValueFeature; +import org.apache.solr.ltr.ranking.LTRComponent; +import org.apache.solr.ltr.ranking.RankSVMModel; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +@SuppressSSL +public class TestFeatureLogging extends TestRerankBase { + + @BeforeClass + public static void setup() throws Exception { + setuptest(); + } + + @AfterClass + public static void after() throws Exception { + aftertest(); + } + + @Test + public void testGeneratedFeatures() throws Exception { + loadFeature("c1", ValueFeature.class.getCanonicalName(), "test1", + "{\"value\":1.0}"); + loadFeature("c2", ValueFeature.class.getCanonicalName(), "test1", + "{\"value\":2.0}"); + loadFeature("c3", ValueFeature.class.getCanonicalName(), "test1", + "{\"value\":3.0}"); + loadFeature("pop", FieldValueFeature.class.getCanonicalName(), "test1", + "{\"field\":\"popularity\"}"); + loadFeature("nomatch", SolrFeature.class.getCanonicalName(), "test1", + "{\"q\":\"{!terms f=title}foobarbat\"}"); + loadFeature("yesmatch", SolrFeature.class.getCanonicalName(), "test1", + "{\"q\":\"{!terms f=popularity}2\"}"); + + loadModel("sum1", RankSVMModel.class.getCanonicalName(), new String[] { + "c1", "c2", "c3"}, "test1", + 
"{\"weights\":{\"c1\":1.0,\"c2\":1.0,\"c3\":1.0}}"); + + SolrQuery query = new SolrQuery(); + query.setQuery("title:bloomberg"); + query.add("fl", "title,description,id,popularity,[fv]"); + query.add("rows", "3"); + query.add("debugQuery", "on"); + query.add(LTRComponent.LTRParams.FV, "true"); + query.add("rq", "{!ltr reRankDocs=3 model=sum1}"); + + String res = restTestHarness.query("/query" + query.toQueryString()); + System.out.println(res); + assertJQ( + "/query" + query.toQueryString(), + "/response/docs/[0]/=={'title':'bloomberg bloomberg ', 'description':'bloomberg','id':'7', 'popularity':2, '[fv]':'c1:1.0;c2:2.0;c3:3.0;pop:2.0;yesmatch:1.0'}"); + + query.remove("fl"); + query.add("fl", "[fv]"); + query.add("rows", "3"); + query.add(LTRComponent.LTRParams.FV, "true"); + query.add("rq", "{!ltr reRankDocs=3 model=sum1}"); + + res = restTestHarness.query("/query" + query.toQueryString()); + System.out.println(res); + assertJQ("/query" + query.toQueryString(), + "/response/docs/[0]/=={'[fv]':'c1:1.0;c2:2.0;c3:3.0;pop:2.0;yesmatch:1.0'}"); + } + + @Test + public void testGeneratedOnlyFeatures() throws Exception { + loadFeature("c1", ValueFeature.class.getCanonicalName(), "test3", + "{\"value\":1.0}"); + loadFeature("c2", ValueFeature.class.getCanonicalName(), "test3", + "{\"value\":2.0}"); + loadFeature("c3", ValueFeature.class.getCanonicalName(), "test3", + "{\"value\":3.0}"); + loadFeature("pop", FieldValueFeature.class.getCanonicalName(), "test3", + "{\"field\":\"popularity\"}"); + + loadModel("sumonly", RankSVMModel.class.getCanonicalName(), new String[] { + "c1", "c2", "c3"}, "test3", + "{\"weights\":{\"c1\":1.0,\"c2\":1.0,\"c3\":1.0}}"); + + SolrQuery query = new SolrQuery(); + query.setQuery("title:bloomberg"); + query.add("fl", "title,description,id,popularity,[fv]"); + query.add("rows", "3"); + query.add("debugQuery", "on"); + query.add(LTRComponent.LTRParams.FV, "true"); + query.add("rq", "{!ltr reRankDocs=3 model=sumonly}"); + + String res = 
restTestHarness.query("/query" + query.toQueryString()); + System.out.println(res); + assertJQ( + "/query" + query.toQueryString(), + "/response/docs/[0]/=={'title':'bloomberg bloomberg ', 'description':'bloomberg','id':'7', 'popularity':2, '[fv]':'c1:1.0;c2:2.0;c3:3.0;pop:2.0'}"); + + query.remove("fl"); + query.add("fl", "fv:[fv]"); + query.add("rows", "3"); + + query.add(LTRComponent.LTRParams.FV, "true"); + query.add("rq", "{!ltr reRankDocs=3 model=sumonly}"); + + res = restTestHarness.query("/query" + query.toQueryString()); + System.out.println(res); + assertJQ("/query" + query.toQueryString(), + "/response/docs/[0]/=={'fv':'c1:1.0;c2:2.0;c3:3.0;pop:2.0'}"); + + } + + @Test + public void testGeneratedGroup() throws Exception { + loadFeature("c1", ValueFeature.class.getCanonicalName(), "testgroup", + "{\"value\":1.0}"); + loadFeature("c2", ValueFeature.class.getCanonicalName(), "testgroup", + "{\"value\":2.0}"); + loadFeature("c3", ValueFeature.class.getCanonicalName(), "testgroup", + "{\"value\":3.0}"); + loadFeature("pop", FieldValueFeature.class.getCanonicalName(), "testgroup", + "{\"field\":\"popularity\"}"); + + loadModel("sumgroup", RankSVMModel.class.getCanonicalName(), new String[] { + "c1", "c2", "c3"}, "testgroup", + "{\"weights\":{\"c1\":1.0,\"c2\":1.0,\"c3\":1.0}}"); + + SolrQuery query = new SolrQuery(); + query.setQuery("title:bloomberg"); + query.add("fl", "*,[fv]"); + query.add("debugQuery", "on"); + + query.remove("fl"); + query.add("fl", "fv:[fv]"); + query.add("rows", "3"); + query.add("group", "true"); + query.add("group.field", "title"); + + query.add(LTRComponent.LTRParams.FV, "true"); + query.add("rq", "{!ltr reRankDocs=3 model=sumgroup}"); + + String res = restTestHarness.query("/query" + query.toQueryString()); + System.out.println(res); + assertJQ( + "/query" + query.toQueryString(), + "/grouped/title/groups/[0]/doclist/docs/[0]/=={'fv':'c1:1.0;c2:2.0;c3:3.0;pop:5.0'}"); + + query.add("fvwt", "json"); + res = 
restTestHarness.query("/query" + query.toQueryString()); + System.out.println(res); + assertJQ( + "/query" + query.toQueryString(), + "/grouped/title/groups/[0]/doclist/docs/[0]/fv/=={'c1':1.0,'c2':2.0,'c3':3.0,'pop':5.0}"); + query.remove("fl"); + query.add("fl", "fv:[fv]"); + + assertJQ( + "/query" + query.toQueryString(), + "/grouped/title/groups/[0]/doclist/docs/[0]/fv/=={'c3':3.0,'pop':5.0,'c1':1.0,'c2':2.0}"); + } + +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureMetadata.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureMetadata.java new file mode 100644 index 000000000000..5590553ec325 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureMetadata.java @@ -0,0 +1,76 @@ +package org.apache.solr.ltr.feature; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import org.apache.solr.ltr.TestRerankBase; +import org.apache.solr.ltr.feature.impl.OriginalScoreFeature; +import org.apache.solr.ltr.feature.impl.ValueFeature; +import org.apache.solr.ltr.ranking.Feature; +import org.apache.solr.ltr.rest.ManagedFeatureStore; +import org.apache.solr.ltr.util.FeatureException; +import org.apache.solr.ltr.util.InvalidFeatureNameException; +import org.apache.solr.ltr.util.NamedParams; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +public class TestFeatureMetadata extends TestRerankBase { + + static ManagedFeatureStore store = null; + + @BeforeClass + public static void setup() throws Exception { + setuptest(); + store = getNewManagedFeatureStore(); + } + + @AfterClass + public static void after() throws Exception { + aftertest(); + } + + @Test + public void getInstanceTest() throws FeatureException, + InvalidFeatureNameException { + + store.addFeature("test", OriginalScoreFeature.class.getCanonicalName(), + "testFstore", NamedParams.EMPTY); + Feature feature = store.getFeatureStore("testFstore").get("test"); + assertEquals("test", feature.getName()); + assertEquals(OriginalScoreFeature.class.getCanonicalName(), feature + .getClass().getCanonicalName()); + } + + @Test(expected = FeatureException.class) + public void getInvalidInstanceTest() throws FeatureException, + InvalidFeatureNameException { + store.addFeature("test", "org.apache.solr.ltr.feature.LOLFeature", + "testFstore2", NamedParams.EMPTY); + + } + + @Test(expected = InvalidFeatureNameException.class) + public void getInvalidNameTest() throws FeatureException, + InvalidFeatureNameException { + + store.addFeature("!!!??????????", ValueFeature.class.getCanonicalName(), + "testFstore3", NamedParams.EMPTY); + + } + +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureStore.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureStore.java new file mode 100644 index 
000000000000..068fa4a8f8d0 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureStore.java @@ -0,0 +1,103 @@ +package org.apache.solr.ltr.feature; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.solr.ltr.TestRerankBase; +import org.apache.solr.ltr.feature.impl.FieldValueFeature; +import org.apache.solr.ltr.feature.impl.OriginalScoreFeature; +import org.apache.solr.ltr.feature.impl.ValueFeature; +import org.apache.solr.ltr.ranking.Feature; +import org.apache.solr.ltr.rest.ManagedFeatureStore; +import org.apache.solr.ltr.util.FeatureException; +import org.apache.solr.ltr.util.InvalidFeatureNameException; +import org.apache.solr.ltr.util.NamedParams; +import org.junit.BeforeClass; +import org.junit.Test; + +public class TestFeatureStore extends TestRerankBase { + + static ManagedFeatureStore fstore = null; + + @BeforeClass + public static void setup() throws Exception { + setuptest(); + fstore = getNewManagedFeatureStore(); + } + + @Test + public void testFeatureStoreAdd() throws InvalidFeatureNameException, + FeatureException { + FeatureStore fs = fstore.getFeatureStore("fstore-testFeature"); + for (int i = 0; i < 5; i++) { + fstore.addFeature("c" + i, 
OriginalScoreFeature.class.getCanonicalName(), + "fstore-testFeature", NamedParams.EMPTY); + + assertTrue(fs.containsFeature("c" + i)); + + } + assertEquals(5, fs.size()); + + } + + @Test + public void testFeatureStoreGet() throws FeatureException, + InvalidFeatureNameException { + FeatureStore fs = fstore.getFeatureStore("fstore-testFeature2"); + for (int i = 0; i < 5; i++) { + + fstore.addFeature("c" + (float) i, ValueFeature.class.getCanonicalName(), + "fstore-testFeature2", new NamedParams().add("value", i)); + + } + + for (float i = 0; i < 5; i++) { + Feature f = fs.get("c" + (float) i); + assertEquals("c" + i, f.getName()); + assertEquals(i, f.getParams().getFloat("value"), 0.0001); + } + } + + @Test(expected = FeatureException.class) + public void testMissingFeature() throws InvalidFeatureNameException, + FeatureException { + FeatureStore fs = fstore.getFeatureStore("fstore-testFeature3"); + for (int i = 0; i < 5; i++) { + fstore.addFeature("testc" + (float) i, + ValueFeature.class.getCanonicalName(), "fstore-testFeature3", + new NamedParams().add("value", i)); + + } + fs.get("missing_feature_name"); + } + + @Test(expected = FeatureException.class) + public void testMissingFeature2() throws InvalidFeatureNameException, + FeatureException { + FeatureStore fs = fstore.getFeatureStore("fstore-testFeature4"); + for (int i = 0; i < 5; i++) { + fstore.addFeature("testc" + (float) i, + ValueFeature.class.getCanonicalName(), "fstore-testFeature4", + new NamedParams().add("value", i)); + + } + fstore.addFeature("invalidparam", + FieldValueFeature.class.getCanonicalName(), "fstore-testFeature4", + NamedParams.EMPTY); + } + +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestModelMetadata.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestModelMetadata.java new file mode 100644 index 000000000000..2575981a1567 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestModelMetadata.java @@ -0,0 +1,154 @@ +package 
org.apache.solr.ltr.feature; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.util.HashMap; +import java.util.Map; + +import org.apache.solr.ltr.TestRerankBase; +import org.apache.solr.ltr.ranking.RankSVMModel; +import org.apache.solr.ltr.rest.ManagedModelStore; +import org.apache.solr.ltr.util.FeatureException; +import org.apache.solr.ltr.util.ModelException; +import org.apache.solr.ltr.util.NamedParams; +import org.junit.BeforeClass; +import org.junit.Test; + +public class TestModelMetadata extends TestRerankBase { + + static ManagedModelStore store = null; + static FeatureStore fstore = null; + + @BeforeClass + public static void setup() throws Exception { + setuptest(); + // loadFeatures("features-store-test-model.json"); + store = getNewManagedModelStore(); + fstore = getNewManagedFeatureStore().getFeatureStore("test"); + + } + + @Test + public void getInstanceTest() throws FeatureException, ModelException { + Map weights = new HashMap<>(); + weights.put("constant1", 1d); + weights.put("constant5", 1d); + + ModelMetadata meta = new RankSVMModel("test1", + RankSVMModel.class.getCanonicalName(), getFeatures(new String[] { + "constant1", "constant5"}), "test", fstore.getFeatures(), + new 
NamedParams().add("weights", weights)); + + store.addMetadataModel(meta); + ModelMetadata m = store.getModel("test1"); + assertEquals(meta, m); + } + + @Test(expected = ModelException.class) + public void getInvalidTypeTest() throws ModelException, FeatureException { + ModelMetadata meta = new RankSVMModel("test2", + "org.apache.solr.ltr.model.LOLModel", getFeatures(new String[] { + "constant1", "constant5"}), "test", fstore.getFeatures(), null); + store.addMetadataModel(meta); + ModelMetadata m = store.getModel("test38290156821076"); + } + + @Test(expected = ModelException.class) + public void getInvalidNameTest() throws ModelException, FeatureException { + ModelMetadata meta = new RankSVMModel("!!!??????????", + RankSVMModel.class.getCanonicalName(), getFeatures(new String[] { + "constant1", "constant5"}), "test", fstore.getFeatures(), null); + store.addMetadataModel(meta); + store.getModel("!!!??????????"); + } + + @Test(expected = ModelException.class) + public void existingNameTest() throws ModelException, FeatureException { + Map weights = new HashMap<>(); + weights.put("constant1", 1d); + weights.put("constant5", 1d); + + ModelMetadata meta = new RankSVMModel("test3", + RankSVMModel.class.getCanonicalName(), getFeatures(new String[] { + "constant1", "constant5"}), "test", fstore.getFeatures(), + new NamedParams().add("weights", weights)); + store.addMetadataModel(meta); + ModelMetadata m = store.getModel("test3"); + assertEquals(meta, m); + store.addMetadataModel(meta); + } + + @Test(expected = ModelException.class) + public void duplicateFeatureTest() throws ModelException, FeatureException { + Map weights = new HashMap<>(); + weights.put("constant1", 1d); + weights.put("constant5", 1d); + + ModelMetadata meta = new RankSVMModel("test4", + RankSVMModel.class.getCanonicalName(), getFeatures(new String[] { + "constant1", "constant1"}), "test", fstore.getFeatures(), + new NamedParams().add("weights", weights)); + store.addMetadataModel(meta); + + } + + 
@Test(expected = ModelException.class) + public void missingFeatureTest() throws ModelException, FeatureException { + Map weights = new HashMap<>(); + weights.put("constant1", 1d); + weights.put("constant5missing", 1d); + + ModelMetadata meta = new RankSVMModel("test5", + RankSVMModel.class.getCanonicalName(), getFeatures(new String[] { + "constant1", "constant1"}), "test", fstore.getFeatures(), + new NamedParams().add("weights", weights)); + store.addMetadataModel(meta); + + } + + @Test(expected = ModelException.class) + public void notExistingClassTest() throws ModelException, FeatureException { + Map weights = new HashMap<>(); + weights.put("constant1", 1d); + weights.put("constant5missing", 1d); + + ModelMetadata meta = new RankSVMModel("test6", + "com.hello.im.a.bad.model.class", getFeatures(new String[] { + "constant1", "constant5"}), "test", fstore.getFeatures(), + new NamedParams().add("weights", weights)); + store.addMetadataModel(meta); + + } + + private class WrongClass {}; + + @Test(expected = ModelException.class) + public void badModelClassTest() throws ModelException, FeatureException { + Map weights = new HashMap<>(); + weights.put("constant1", 1d); + weights.put("constant5missing", 1d); + + ModelMetadata meta = new RankSVMModel("test7", + WrongClass.class.getCanonicalName(), getFeatures(new String[] { + "constant1", "constant5"}), "test", fstore.getFeatures(), + new NamedParams().add("weights", weights)); + store.addMetadataModel(meta); + + } + +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestEdisMaxSolrFeature.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestEdisMaxSolrFeature.java new file mode 100644 index 000000000000..90e7aaa04af9 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestEdisMaxSolrFeature.java @@ -0,0 +1,46 @@ +package org.apache.solr.ltr.feature.impl; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.ltr.ranking.RankSVMModel; +import org.junit.Test; + +public class TestEdisMaxSolrFeature extends TestQueryFeature { + @Test + public void testEdisMaxSolrFeature() throws Exception { + loadFeature( + "SomeEdisMax", + SolrFeature.class.getCanonicalName(), + "{\"q\":\"{!edismax qf='title description' pf='description' mm=100% boost='pow(popularity, 0.1)' v='w1' tie=0.1}\"}"); + + loadModel("EdisMax-model", RankSVMModel.class.getCanonicalName(), + new String[] {"SomeEdisMax"}, "{\"weights\":{\"SomeEdisMax\":1.0}}"); + + SolrQuery query = new SolrQuery(); + query.setQuery("title:w1"); + query.add("fl", "*, score"); + query.add("rows", "4"); + + query.add("rq", "{!ltr model=EdisMax-model reRankDocs=4}"); + query.set("debugQuery", "on"); + String res = restTestHarness.query("/query" + query.toQueryString()); + System.out.println(res); + assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); + } +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestExternalFeatures.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestExternalFeatures.java new file mode 100644 index 000000000000..163745df6114 --- /dev/null +++ 
b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestExternalFeatures.java @@ -0,0 +1,119 @@ +package org.apache.solr.ltr.feature.impl; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.lucene.util.LuceneTestCase.SuppressCodecs; +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.ltr.TestRerankBase; +import org.apache.solr.ltr.ranking.LTRComponent; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +@SuppressCodecs({"Lucene3x", "Lucene41", "Lucene40", "Appending"}) +public class TestExternalFeatures extends TestRerankBase { + + @BeforeClass + public static void before() throws Exception { + setuptest("solrconfig-ltr.xml", "schema-ltr.xml"); + + assertU(adoc("id", "1", "title", "w1", "description", "w1", "popularity", + "1")); + assertU(adoc("id", "2", "title", "w2", "description", "w2", "popularity", + "2")); + assertU(adoc("id", "3", "title", "w3", "description", "w3", "popularity", + "3")); + assertU(adoc("id", "4", "title", "w4", "description", "w4", "popularity", + "4")); + assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity", + "5")); + assertU(commit()); + + loadFeatures("external_features.json"); + 
loadModels("external_model.json"); + } + + @AfterClass + public static void after() throws Exception { + aftertest(); + } + + @Test + public void externalTest1() throws Exception { + SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("fl", "*,score"); + query.add("rows", "3"); + query.add(LTRComponent.LTRParams.FV, "true"); + + // Regular scores + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==1.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='2'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/score==1.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='3'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/score==1.0"); + + query.add("fl", "[fv]"); + // Model is not specified so we should get a model does not exist exception + assertJQ("/query" + query.toQueryString(), "/error/msg=='model is null'"); + + // No match scores since user_query not passed in to external feature info + // and feature depended on it. 
+ query.add("rq", "{!ltr reRankDocs=3 model=externalmodel}"); + + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==0.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='2'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/score==0.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='3'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/score==0.0"); + + // Matched user query since it was passed in + query.remove("rq"); + query + .add("rq", "{!ltr reRankDocs=3 model=externalmodel efi.user_query=w3}"); + + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='3'"); + assertJQ("/query" + query.toQueryString(), + "/response/docs/[0]/score==0.999"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='1'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/score==0.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='2'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/score==0.0"); + + System.out.println(restTestHarness.query("/query" + query.toQueryString())); + } + + @Test + public void externalStopwordTest() throws Exception { + SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("fl", "*,score,fv:[fv]"); + query.add("rows", "1"); + query.add(LTRComponent.LTRParams.FV, "true"); + // Stopword only query passed in + query.add("rq", + "{!ltr reRankDocs=3 model=externalmodel efi.user_query='a'}"); + + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/fv==''"); + + System.out.println(restTestHarness.query("/query" + query.toQueryString())); + } + +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestFieldLengthFeature.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestFieldLengthFeature.java new file mode 100644 index 
000000000000..7f66377259ce --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestFieldLengthFeature.java @@ -0,0 +1,120 @@ +package org.apache.solr.ltr.feature.impl; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import org.apache.lucene.util.LuceneTestCase.SuppressCodecs; +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.ltr.TestRerankBase; +import org.apache.solr.ltr.ranking.RankSVMModel; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +@SuppressCodecs({"Lucene3x", "Lucene41", "Lucene40", "Appending"}) +public class TestFieldLengthFeature extends TestRerankBase { + + @BeforeClass + public static void before() throws Exception { + setuptest("solrconfig-ltr.xml", "schema-ltr.xml"); + + assertU(adoc("id", "1", "title", "w1", "description", "w1")); + assertU(adoc("id", "2", "title", "w2 2asd asdd didid", "description", + "w2 2asd asdd didid")); + assertU(adoc("id", "3", "title", "w3", "description", "w3")); + assertU(adoc("id", "4", "title", "w4", "description", "w4")); + assertU(adoc("id", "5", "title", "w5", "description", "w5")); + assertU(adoc("id", "6", "title", "w1 w2", "description", "w1 w2")); + assertU(adoc("id", "7", "title", "w1 w2 w3 w4 w5", "description", + "w1 w2 w3 w4 w5 w8")); + assertU(adoc("id", "8", "title", "w1 w1 w1 w2 w2 w8", "description", + "w1 w1 w1 w2 w2")); + assertU(commit()); + } + + @AfterClass + public static void after() throws Exception { + aftertest(); + } + + @Test + public void testRanking() throws Exception { + loadFeature("title-length", FieldLengthFeature.class.getCanonicalName(), + "{\"field\":\"title\"}"); + + loadModel("title-model", RankSVMModel.class.getCanonicalName(), + new String[] {"title-length"}, "{\"weights\":{\"title-length\":1.0}}"); + + SolrQuery query = new SolrQuery(); + query.setQuery("title:w1"); + query.add("fl", "*, score"); + query.add("rows", "4"); + + String res; + // res = restTestHarness.query("/query" + query.toQueryString()); + // System.out.println(res); + // Normal term match + assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'"); + assertJQ("/query" 
+ query.toQueryString(), "/response/docs/[1]/id=='8'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='6'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='7'"); + // Normal term match + + query.add("rq", "{!ltr model=title-model reRankDocs=4}"); + res = restTestHarness.query("/query" + query.toQueryString()); + + assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='8'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='7'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='6'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='1'"); + + query.setQuery("*:*"); + query.remove("rows"); + query.add("rows", "8"); + query.remove("rq"); + query.add("rq", "{!ltr model=title-model reRankDocs=8}"); + res = restTestHarness.query("/query" + query.toQueryString()); + System.out.println(res); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='8'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='7'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='2'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='6'"); + + loadFeature("description-length", + FieldLengthFeature.class.getCanonicalName(), + "{\"field\":\"description\"}"); + loadModel("description-model", RankSVMModel.class.getCanonicalName(), + new String[] {"description-length"}, + "{\"weights\":{\"description-length\":1.0}}"); + query.setQuery("title:w1"); + query.remove("rq"); + query.remove("rows"); + query.add("rows", "4"); + query.add("rq", "{!ltr model=description-model reRankDocs=4}"); + res = restTestHarness.query("/query" + query.toQueryString()); + System.out.println(res); + + assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='7'"); + assertJQ("/query" 
+ query.toQueryString(), "/response/docs/[1]/id=='8'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='6'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='1'"); + } + +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestFieldValueFeature.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestFieldValueFeature.java new file mode 100644 index 000000000000..da206ae23577 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestFieldValueFeature.java @@ -0,0 +1,105 @@ +package org.apache.solr.ltr.feature.impl; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import org.apache.lucene.util.LuceneTestCase.SuppressCodecs; +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.ltr.TestRerankBase; +import org.apache.solr.ltr.ranking.RankSVMModel; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +@SuppressCodecs({"Lucene3x", "Lucene41", "Lucene40", "Appending"}) +public class TestFieldValueFeature extends TestRerankBase { + + @BeforeClass + public static void before() throws Exception { + setuptest("solrconfig-ltr.xml", "schema-ltr.xml"); + + assertU(adoc("id", "1", "title", "w1", "description", "w1", "popularity", + "1")); + assertU(adoc("id", "2", "title", "w2 2asd asdd didid", "description", + "w2 2asd asdd didid", "popularity", "2")); + assertU(adoc("id", "3", "title", "w3", "description", "w3", "popularity", + "3")); + assertU(adoc("id", "4", "title", "w4", "description", "w4", "popularity", + "4")); + assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity", + "5")); + assertU(adoc("id", "6", "title", "w1 w2", "description", "w1 w2", + "popularity", "6")); + assertU(adoc("id", "7", "title", "w1 w2 w3 w4 w5", "description", + "w1 w2 w3 w4 w5 w8", "popularity", "7")); + assertU(adoc("id", "8", "title", "w1 w1 w1 w2 w2 w8", "description", + "w1 w1 w1 w2 w2", "popularity", "8")); + assertU(commit()); + } + + @AfterClass + public static void after() throws Exception { + aftertest(); + } + + @Test + public void testRanking() throws Exception { + loadFeature("popularity", FieldValueFeature.class.getCanonicalName(), + "{\"field\":\"popularity\"}"); + + loadModel("popularity-model", RankSVMModel.class.getCanonicalName(), + new String[] {"popularity"}, "{\"weights\":{\"popularity\":1.0}}"); + + SolrQuery query = new SolrQuery(); + query.setQuery("title:w1"); + query.add("fl", "*, score"); + query.add("rows", "4"); + + String res; + // res = restTestHarness.query("/query" + query.toQueryString()); + // System.out.println(res); + // Normal term match + 
assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='8'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='6'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='7'"); + // Rerank the top 4 results with popularity-model + + query.add("rq", "{!ltr model=popularity-model reRankDocs=4}"); + res = restTestHarness.query("/query" + query.toQueryString()); + + assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='8'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='7'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='6'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='1'"); + + query.setQuery("*:*"); + query.remove("rows"); + query.add("rows", "8"); + query.remove("rq"); + query.add("rq", "{!ltr model=popularity-model reRankDocs=8}"); + res = restTestHarness.query("/query" + query.toQueryString()); + System.out.println(res); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='8'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='7'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='6'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='5'"); + } + +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestFilterSolrFeature.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestFilterSolrFeature.java new file mode 100644 index 000000000000..5957b8d0aa1d --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestFilterSolrFeature.java @@ -0,0 +1,90 @@ +package org.apache.solr.ltr.feature.impl; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.lucene.util.LuceneTestCase.SuppressCodecs; +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.ltr.TestRerankBase; +import org.apache.solr.ltr.ranking.RankSVMModel; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +@SuppressCodecs({"Lucene3x", "Lucene41", "Lucene40", "Appending"}) +public class TestFilterSolrFeature extends TestRerankBase { + @BeforeClass + public static void before() throws Exception { + setuptest("solrconfig-ltr.xml", "schema-ltr.xml"); + + assertU(adoc("id", "1", "title", "w1", "description", "w1", "popularity", + "1")); + assertU(adoc("id", "2", "title", "w2 2asd asdd didid", "description", + "w2 2asd asdd didid", "popularity", "2")); + assertU(adoc("id", "3", "title", "w1", "description", "w1", "popularity", + "3")); + assertU(adoc("id", "4", "title", "w1", "description", "w1", "popularity", + "4")); + assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity", + "5")); + assertU(adoc("id", "6", "title", "w6 w2", "description", "w1 w2", + "popularity", "6")); + assertU(adoc("id", "7", "title", "w1 w2 w3 w4 w5", "description", + "w6 w2 w3 w4 w5 w8", "popularity", "88888")); + assertU(adoc("id", "8", "title", "w1 w1 w1 w2 w2 w8", "description", + "w1 w1 w1 w2 w2", 
"popularity", "88888")); + assertU(commit()); + } + + @AfterClass + public static void after() throws Exception { + aftertest(); + } + + @Test + public void testUserTermScoreWithFQ() throws Exception { + loadFeature("SomeTermFQ", SolrFeature.class.getCanonicalName(), + "{\"fq\":[\"{!terms f=popularity}88888\"]}"); + loadFeature("SomeEfiFQ", SolrFeature.class.getCanonicalName(), + "{\"fq\":[\"{!terms f=title}${user_query}\"]}"); + loadModel("Term-modelFQ", RankSVMModel.class.getCanonicalName(), + new String[] {"SomeTermFQ", "SomeEfiFQ"}, + "{\"weights\":{\"SomeTermFQ\":1.6, \"SomeEfiFQ\":2.0}}"); + SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("fl", "*, score"); + query.add("rows", "3"); + query.add("fq", "{!terms f=title}w1"); + query.add("rq", + "{!ltr model=Term-modelFQ reRankDocs=5 efi.user_query='w5'}"); + + String res = restTestHarness.query("/query" + query.toQueryString()); + System.out.println(res); + assertJQ("/query" + query.toQueryString(), "/response/numFound/==5"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==3.6"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/score==1.6"); + } + + @Test + public void testBadFeature() throws Exception { + // Missing q/fq + String feature = getFeatureInJson("badFeature", "test", + SolrFeature.class.getCanonicalName(), "{\"df\":\"foo\"]}"); + assertJPut(FEATURE_ENDPOINT, feature, "/responseHeader/status==500"); + } + +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestLambdaMARTModel.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestLambdaMARTModel.java new file mode 100644 index 000000000000..9256b4d3c1e2 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestLambdaMARTModel.java @@ -0,0 +1,219 @@ +package org.apache.solr.ltr.feature.impl; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//import static org.junit.internal.matchers.StringContains.containsString; + +import java.lang.invoke.MethodHandles; + +import org.apache.lucene.util.LuceneTestCase.SuppressCodecs; +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.common.SolrException; +import org.apache.solr.ltr.TestRerankBase; +import org.apache.solr.ltr.ranking.LTRComponent; +import org.apache.solr.ltr.util.ModelException; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@SuppressCodecs({"Lucene3x", "Lucene41", "Lucene40", "Appending"}) +public class TestLambdaMARTModel extends TestRerankBase { + + @SuppressWarnings("unused") + private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + + @BeforeClass + public static void before() throws Exception { + setuptest("solrconfig-ltr.xml", "schema-ltr.xml"); + + h.update(adoc("id", "1", "title", "w1", "description", "w1", "popularity", + "1")); + h.update(adoc("id", "2", "title", "w2", "description", "w2", "popularity", + "2")); + h.update(adoc("id", "3", "title", "w3", "description", "w3", "popularity", + "3")); + h.update(adoc("id", "4", "title", "w4", "description", "w4", 
"popularity", + "4")); + h.update(adoc("id", "5", "title", "w5", "description", "w5", "popularity", + "5")); + h.update(commit()); + + loadFeatures("lambdamart_features.json"); // currently needed to force + // scoring on all docs + loadModels("lambdamart_model.json"); + } + + @AfterClass + public static void after() throws Exception { + aftertest(); + } + + @Ignore + @Test + public void lambdaMartTest1() throws Exception { + SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("rows", "3"); + query.add(LTRComponent.LTRParams.FV, "true"); + query.add("fl", "*,score"); + + // Regular scores + // System.out.println(restTestHarness.query(request) + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==1.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='2'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/score==1.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='3'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/score==1.0"); + + // No match scores since user_query not passed in to external feature info + // and feature depended on it. 
+ query.add("fl", "[fv]"); + query.add("rq", "{!ltr reRankDocs=3 model=lambdamartmodel}"); + + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'"); + assertJQ("/query" + query.toQueryString(), + "/response/docs/[0]/score==-120.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='2'"); + assertJQ("/query" + query.toQueryString(), + "/response/docs/[1]/score==-120.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='3'"); + assertJQ("/query" + query.toQueryString(), + "/response/docs/[2]/score==-120.0"); + + // Matched user query since it was passed in + query.remove("rq"); + query.add("rq", + "{!ltr reRankDocs=3 model=lambdamartmodel efi.user_query=w3}"); + + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='3'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==30.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='1'"); + assertJQ("/query" + query.toQueryString(), + "/response/docs/[1]/score==-120.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='2'"); + assertJQ("/query" + query.toQueryString(), + "/response/docs/[2]/score==-120.0"); + + System.out.println(restTestHarness.query("/query" + query.toQueryString())); + + } + + @Ignore + @Test + public void lambdaMartTestExplain() throws Exception { + SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("fl", "*,score,[fv]"); + query.add("rows", "3"); + query.add(LTRComponent.LTRParams.FV, "true"); + + query.add("rq", + "{!ltr reRankDocs=3 model=lambdamartmodel efi.user_query=w3}"); + + // test out the explain feature, make sure it returns something + query.setParam("debugQuery", "on"); + String qryResult = JQ("/query" + query.toQueryString()); + + System.out.println(qryResult); + + qryResult = qryResult.replaceAll("\n", " "); + // FIXME containsString doesn't exist. 
+ // assertThat(qryResult, containsString("\"debug\":{")); + // qryResult = qryResult.substring(qryResult.indexOf("debug")); + // + // assertThat(qryResult, containsString("\"explain\":{")); + // qryResult = qryResult.substring(qryResult.indexOf("explain")); + // + // assertThat(qryResult, containsString("lambdamartmodel")); + // assertThat(qryResult, + // containsString("org.apache.solr.ltr.ranking.LambdaMARTModel")); + // + // assertThat(qryResult, containsString("-100.0 = tree 0")); + // assertThat(qryResult, containsString("50.0 = tree 0")); + // assertThat(qryResult, containsString("-20.0 = tree 1")); + // assertThat(qryResult, containsString("'matchedTitle':1.0 > 0.5")); + // assertThat(qryResult, containsString("'matchedTitle':0.0 <= 0.5")); + // + // assertThat(qryResult, containsString(" Go Right ")); + // assertThat(qryResult, containsString(" Go Left ")); + // assertThat(qryResult, + // containsString("'this_feature_doesnt_exist' does not exist in FV")); + + System.out.println(restTestHarness.query("/query" + query.toQueryString())); + } + + @Test(expected = ModelException.class) + public void lambdaMartTestNoParams() throws Exception { + createModelFromFiles("lambdamart_model_no_params.json", + "lambdamart_features.json"); + + } + + @Test(expected = ModelException.class) + public void lambdaMartTestNoTrees() throws Exception { + createModelFromFiles("lambdamart_model_no_trees.json", + "lambdamart_features.json"); + } + + @Test(expected = ModelException.class) + public void lambdaMartTestNoWeight() throws Exception { + createModelFromFiles("lambdamart_model_no_weight.json", + "lambdamart_features.json"); + } + + @Test(expected = ModelException.class) + public void lambdaMartTestNoTree() throws Exception { + createModelFromFiles("lambdamart_model_no_tree.json", + "lambdamart_features.json"); + } + + @Test(expected = SolrException.class) + public void lambdaMartTestNoFeatures() throws Exception { + createModelFromFiles("lambdamart_model_no_features.json", 
+ "lambdamart_features.json"); + } + + @Test(expected = ModelException.class) + public void lambdaMartTestNoRight() throws Exception { + createModelFromFiles("lambdamart_model_no_right.json", + "lambdamart_features.json"); + + } + + @Test(expected = ModelException.class) + public void lambdaMartTestNoLeft() throws Exception { + createModelFromFiles("lambdamart_model_no_left.json", + "lambdamart_features.json"); + } + + @Test(expected = ModelException.class) + public void lambdaMartTestNoThreshold() throws Exception { + createModelFromFiles("lambdamart_model_no_threshold.json", + "lambdamart_features.json"); + + } + + @Test(expected = ModelException.class) + public void lambdaMartTestNoFeature() throws Exception { + createModelFromFiles("lambdamart_model_no_feature.json", + "lambdamart_features.json"); + } +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestNoMatchSolrFeature.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestNoMatchSolrFeature.java new file mode 100644 index 000000000000..5d0e97ef5535 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestNoMatchSolrFeature.java @@ -0,0 +1,206 @@ +package org.apache.solr.ltr.feature.impl; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.util.ArrayList; +import java.util.Map; + +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.ltr.TestRerankBase; +import org.apache.solr.ltr.ranking.LambdaMARTModel; +import org.apache.solr.ltr.ranking.RankSVMModel; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import org.noggit.ObjectBuilder; + +public class TestNoMatchSolrFeature extends TestRerankBase { + + @BeforeClass + public static void before() throws Exception { + setuptest("solrconfig-ltr.xml", "schema-ltr.xml"); + + assertU(adoc("id", "1", "title", "w1", "description", "w1", "popularity", + "1")); + assertU(adoc("id", "2", "title", "w2 2asd asdd didid", "description", + "w2 2asd asdd didid", "popularity", "2")); + assertU(adoc("id", "3", "title", "w3", "description", "w3", "popularity", + "3")); + assertU(adoc("id", "4", "title", "w4", "description", "w4", "popularity", + "4")); + assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity", + "5")); + assertU(adoc("id", "6", "title", "w1 w2", "description", "w1 w2", + "popularity", "6")); + assertU(adoc("id", "7", "title", "w1 w2 w3 w4 w5", "description", + "w1 w2 w3 w4 w5 w8", "popularity", "7")); + assertU(adoc("id", "8", "title", "w1 w1 w1 w2 w2 w8", "description", + "w1 w1 w1 w2 w2", "popularity", "8")); + assertU(commit()); + + loadFeature("nomatchfeature", SolrFeature.class.getCanonicalName(), + "{\"q\":\"foobarbat12345\",\"df\":\"title\"}"); + loadFeature("yesmatchfeature", SolrFeature.class.getCanonicalName(), + "{\"q\":\"w1\",\"df\":\"title\"}"); + loadFeature("nomatchfeature2", SolrFeature.class.getCanonicalName(), + "{\"q\":\"foobarbat12345\",\"df\":\"title\"}"); + loadModel( + "nomatchmodel", + RankSVMModel.class.getCanonicalName(), + new String[] {"nomatchfeature", "yesmatchfeature", "nomatchfeature2"}, + 
"{\"weights\":{\"nomatchfeature\":1.0,\"yesmatchfeature\":1.1,\"nomatchfeature2\":1.1}}"); + + loadFeature("nomatchfeature3", SolrFeature.class.getCanonicalName(), + "{\"q\":\"foobarbat12345\",\"df\":\"title\"}"); + loadModel("nomatchmodel2", RankSVMModel.class.getCanonicalName(), + new String[] {"nomatchfeature3"}, + "{\"weights\":{\"nomatchfeature3\":1.0}}"); + + loadFeature("nomatchfeature4", SolrFeature.class.getCanonicalName(), + "noMatchFeaturesStore", "{\"q\":\"foobarbat12345\",\"df\":\"title\"}"); + loadModel("nomatchmodel3", RankSVMModel.class.getCanonicalName(), + new String[] {"nomatchfeature4"}, "noMatchFeaturesStore", + "{\"weights\":{\"nomatchfeature4\":1.0}}"); + } + + @AfterClass + public static void after() throws Exception { + aftertest(); + } + + @Test + public void testNoMatchSolrFeat1() throws Exception { + // Tests model with all no matching features but 1 + SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("fl", "*, score,fv:[fv]"); + query.add("rows", "4"); + query.add("fv", "true"); + query.add("rq", "{!ltr model=nomatchmodel reRankDocs=4}"); + + SolrQuery yesMatchFeatureQuery = new SolrQuery(); + yesMatchFeatureQuery.setQuery("title:w1"); + yesMatchFeatureQuery.add("fl", "score"); + yesMatchFeatureQuery.add("rows", "4"); + String res = restTestHarness.query("/query" + + yesMatchFeatureQuery.toQueryString()); + System.out.println(res); + Map jsonParse = (Map) ObjectBuilder + .fromJSON(res); + Double doc0Score = (Double) ((Map) ((ArrayList) ((Map) jsonParse + .get("response")).get("docs")).get(0)).get("score"); + + res = restTestHarness.query("/query" + query.toQueryString()); + System.out.println(res); + + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==" + + doc0Score * 1.1); + assertJQ("/query" + query.toQueryString(), + "/response/docs/[0]/fv=='yesmatchfeature:" + doc0Score + "'"); + assertJQ("/query" + 
query.toQueryString(), "/response/docs/[1]/id=='2'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/score==0.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/fv==''"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='3'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/score==0.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/fv==''"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='4'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/score==0.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/fv==''"); + } + + @Test + public void testNoMatchSolrFeat2() throws Exception { + // Tests model with all no matching features, but 1 non-model matching + // feature for logging + SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("fl", "*, score,fv:[fv]"); + query.add("rows", "4"); + query.add("fv", "true"); + query.add("rq", "{!ltr model=nomatchmodel2 reRankDocs=4}"); + + SolrQuery yesMatchFeatureQuery = new SolrQuery(); + yesMatchFeatureQuery.setQuery("title:w1"); + yesMatchFeatureQuery.add("fl", "score"); + yesMatchFeatureQuery.add("rows", "4"); + String res = restTestHarness.query("/query" + + yesMatchFeatureQuery.toQueryString()); + System.out.println(res); + Map jsonParse = (Map) ObjectBuilder + .fromJSON(res); + Double doc0Score = (Double) ((Map) ((ArrayList) ((Map) jsonParse + .get("response")).get("docs")).get(0)).get("score"); + + res = restTestHarness.query("/query" + query.toQueryString()); + System.out.println(res); + + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==0.0"); + assertJQ("/query" + query.toQueryString(), + "/response/docs/[0]/fv=='yesmatchfeature:" + doc0Score + "'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/score==0.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/fv==''"); + assertJQ("/query" + 
query.toQueryString(), "/response/docs/[2]/score==0.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/fv==''"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/score==0.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/fv==''"); + } + + @Test + public void testNoMatchSolrFeat3() throws Exception { + // Tests model with all no matching features + SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("fl", "*, score,fv:[fv]"); + query.add("rows", "4"); + query.add("fv", "true"); + query.add("rq", "{!ltr model=nomatchmodel3 reRankDocs=4}"); + + String res = restTestHarness.query("/query" + query.toQueryString()); + System.out.println(res); + + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==0.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/fv==''"); + } + + @Test + public void testNoMatchSolrFeat4() throws Exception { + // Tests model with all no matching features but expects a non 0 score + loadModel( + "nomatchmodel4", + LambdaMARTModel.class.getCanonicalName(), + new String[] {"nomatchfeature4"}, + "noMatchFeaturesStore", + "{\"trees\":[{\"weight\":1.0, \"tree\":{\"feature\": \"matchedTitle\",\"threshold\": 0.5,\"left\":{\"value\" : -10},\"right\":{\"value\" : 9}}}]}"); + + SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("fl", "*, score,fv:[fv]"); + query.add("rows", "4"); + query.add("fv", "true"); + query.add("rq", "{!ltr model=nomatchmodel4 reRankDocs=4}"); + + String res = restTestHarness.query("/query" + query.toQueryString()); + System.out.println(res); + + assertJQ("/query" + query.toQueryString(), + "/response/docs/[0]/score==-10.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/fv==''"); + } + +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestOriginalScoreFeature.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestOriginalScoreFeature.java new file 
mode 100644 index 000000000000..c90243815061 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestOriginalScoreFeature.java @@ -0,0 +1,169 @@ +package org.apache.solr.ltr.feature.impl; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import org.apache.lucene.util.LuceneTestCase.SuppressCodecs; +import org.apache.solr.JSONTestUtil; +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.common.util.StrUtils; +import org.apache.solr.ltr.TestRerankBase; +import org.apache.solr.ltr.ranking.LTRComponent; +import org.apache.solr.ltr.ranking.RankSVMModel; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import org.noggit.ObjectBuilder; + +@SuppressCodecs({"Lucene3x", "Lucene41", "Lucene40", "Appending"}) +public class TestOriginalScoreFeature extends TestRerankBase { + + @BeforeClass + public static void before() throws Exception { + setuptest("solrconfig-ltr.xml", "schema-ltr.xml"); + + assertU(adoc("id", "1", "title", "w1")); + assertU(adoc("id", "2", "title", "w2")); + assertU(adoc("id", "3", "title", "w3")); + assertU(adoc("id", "4", "title", "w4")); + assertU(adoc("id", "5", "title", "w5")); + assertU(adoc("id", "6", "title", "w1 w2")); + assertU(adoc("id", "7", "title", "w1 w2 w3 w4 w5")); + assertU(adoc("id", "8", "title", "w1 w1 w1 w2 w2")); + assertU(commit()); + } + + @AfterClass + public static void after() throws Exception { + aftertest(); + } + + @Test + public void testOriginalScore() throws Exception { + loadFeature("score", OriginalScoreFeature.class.getCanonicalName(), "{}"); + + loadModel("originalScore", RankSVMModel.class.getCanonicalName(), + new String[] {"score"}, "{\"weights\":{\"score\":1.0}}"); + + SolrQuery query = new SolrQuery(); + query.setQuery("title:w1"); + query.add("fl", "*, score"); + query.add("rows", "4"); + query.add("wt", "json"); + + // String res = restTestHarness.query("/query" + query.toQueryString()); + // System.out.println(res); + + // Normal term 
match + assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='8'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='6'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='7'"); + + String res = restTestHarness.query("/query" + query.toQueryString()); + Map jsonParse = (Map) ObjectBuilder + .fromJSON(res); + String doc0Score = ((Double) ((Map) ((ArrayList) ((Map) jsonParse + .get("response")).get("docs")).get(0)).get("score")).toString(); + String doc1Score = ((Double) ((Map) ((ArrayList) ((Map) jsonParse + .get("response")).get("docs")).get(1)).get("score")).toString(); + String doc2Score = ((Double) ((Map) ((ArrayList) ((Map) jsonParse + .get("response")).get("docs")).get(2)).get("score")).toString(); + String doc3Score = ((Double) ((Map) ((ArrayList) ((Map) jsonParse + .get("response")).get("docs")).get(3)).get("score")).toString(); + + query.add("fl", "[fv]"); + query.add(LTRComponent.LTRParams.FV, "true"); + query.add("rq", "{!ltr model=originalScore reRankDocs=4}"); + + // res = restTestHarness.query("/query" + query.toQueryString()); + // System.out.println(res); + + assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==" + + doc0Score); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='8'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/score==" + + doc1Score); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='6'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/score==" + + doc2Score); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='7'"); + assertJQ("/query" + query.toQueryString(), 
"/response/docs/[3]/score==" + + doc3Score); + } + + @Test + public void testOriginalScoreWithNonScoringFeatures() throws Exception { + loadFeature("origScore", OriginalScoreFeature.class.getCanonicalName(), + "store2", "{}"); + loadFeature("c2", ValueFeature.class.getCanonicalName(), "store2", + "{\"value\":2.0}"); + + loadModel("origScore", RankSVMModel.class.getCanonicalName(), + new String[] {"origScore"}, "store2", + "{\"weights\":{\"origScore\":1.0}}"); + + SolrQuery query = new SolrQuery(); + query.setQuery("title:w1"); + query.add("fl", "*, score, fv:[fv]"); + query.add("rows", "4"); + query.add("wt", "json"); + query.add(LTRComponent.LTRParams.FV, "true"); + query.add("rq", "{!ltr model=origScore reRankDocs=4}"); + + String res = restTestHarness.query("/query" + query.toQueryString()); + Map jsonParse = (Map) ObjectBuilder + .fromJSON(res); + String doc0Score = ((Double) ((Map) ((ArrayList) ((Map) jsonParse + .get("response")).get("docs")).get(0)).get("score")).toString(); + String doc1Score = ((Double) ((Map) ((ArrayList) ((Map) jsonParse + .get("response")).get("docs")).get(1)).get("score")).toString(); + String doc2Score = ((Double) ((Map) ((ArrayList) ((Map) jsonParse + .get("response")).get("docs")).get(2)).get("score")).toString(); + String doc3Score = ((Double) ((Map) ((ArrayList) ((Map) jsonParse + .get("response")).get("docs")).get(3)).get("score")).toString(); + System.out.println(doc0Score); + + assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'"); + assertJQ("/query" + query.toQueryString(), + "/response/docs/[0]/fv=='origScore:" + doc0Score + ";c2:2.0'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='8'"); + assertJQ("/query" + query.toQueryString(), + "/response/docs/[1]/fv=='origScore:" + doc1Score + ";c2:2.0'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='6'"); + assertJQ("/query" + 
query.toQueryString(), + "/response/docs/[2]/fv=='origScore:" + doc2Score + ";c2:2.0'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='7'"); + assertJQ("/query" + query.toQueryString(), + "/response/docs/[3]/fv=='origScore:" + doc3Score + ";c2:2.0'"); + } + +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestQueryFeature.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestQueryFeature.java new file mode 100644 index 000000000000..ecd3b01dc313 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestQueryFeature.java @@ -0,0 +1,99 @@ +package org.apache.solr.ltr.feature.impl; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import org.apache.lucene.util.LuceneTestCase.SuppressCodecs; +import org.apache.solr.ltr.TestRerankBase; +import org.junit.AfterClass; +import org.junit.BeforeClass; + +@SuppressCodecs({"Lucene3x", "Lucene41", "Lucene40", "Appending"}) +public class TestQueryFeature extends TestRerankBase { + + @BeforeClass + public static void before() throws Exception { + setuptest("solrconfig-ltr.xml", "schema-ltr.xml"); + + assertU(adoc("id", "1", "title", "w1", "description", "w1", "popularity", + "1")); + assertU(adoc("id", "2", "title", "w2 2asd asdd didid", "description", + "w2 2asd asdd didid", "popularity", "2")); + assertU(adoc("id", "3", "title", "w3", "description", "w3", "popularity", + "3")); + assertU(adoc("id", "4", "title", "w4", "description", "w4", "popularity", + "4")); + assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity", + "5")); + assertU(adoc("id", "6", "title", "w1 w2", "description", "w1 w2", + "popularity", "6")); + assertU(adoc("id", "7", "title", "w1 w2 w3 w4 w5", "description", + "w1 w2 w3 w4 w5 w8", "popularity", "7")); + assertU(adoc("id", "8", "title", "w1 w1 w1 w2 w2 w8", "description", + "w1 w1 w1 w2 w2", "popularity", "8")); + assertU(commit()); + } + + @AfterClass + public static void after() throws Exception { + aftertest(); + } + + /* + * @Test public void testEdisMaxSolrFeatureFilterQ() throws Exception { + * //before(); loadFeature("SomeEdisMaxFQ", + * SolrFeature.class.getCanonicalName(), + * "{\"fq\":[\"{!edismax qf='title description' pf='description' mm=100% boost='pow(popularity, 0.1)' v='w1' tie=0.1}\"]}" + * ); + * + * loadModel("EdisMax-modelFQ", RankSVMModel.class.getCanonicalName(), new + * String[] { "SomeEdisMaxFQ" }, "{\"weights\":{\"SomeEdisMaxFQ\":1.0}}"); + * + * SolrQuery query = new SolrQuery(); query.setQuery("title:w1"); + * query.add("fl", "*, score"); query.add("rows", "4"); query.add("rq", + * "{!ltr model=EdisMax-modelFQ reRankDocs=4}"); query.set("debugQuery", + * "on"); String res = 
restTestHarness.query("/query?" + query.toString()); + * System.out.println(res); assertJQ("/query?" + query.toString(), + * "/response/numFound/==4"); //aftertest(); } + * + * + * public void testUserTermScoreWithFQ() throws Exception { before(); + * loadFeature("SomeTermFQ", SolrFeature.class.getCanonicalName(), + * "{\"fq\":[\"{!terms f=popularity}88888\"]}"); loadModel("Term-modelFQ", + * RankSVMModel.class.getCanonicalName(), new String[] { "SomeTermFQ" }, + * "{\"weights\":{\"SomeTermFQ\":1.5}}"); SolrQuery query = new SolrQuery(); + * query.setQuery("title:w1"); query.add("fl", "*, score"); query.add("rows", + * "4"); query.add("rq", "{!ltr model=Term-modelFQ reRankDocs=4}"); + * query.set("debugQuery", "on"); String res = restTestHarness.query("/query?" + * + query.toString()); System.out.println(res); assertJQ("/query?" + + * query.toString(), "/response/numFound/==4"); assertJQ("/query?" + + * query.toString(), "/response/docs/[0]/score==0.0"); assertJQ("/query?" + + * query.toString(), "/response/docs/[1]/score==0.0"); aftertest(); } + * + * public void testUserTermScorerQWithQuery() throws Exception { before(); + * loadFeature("matchedTitle", SolrFeature.class.getCanonicalName(), + * "{\"q\":\"title:QUERY\"}"); loadModel("Term-matchedTitle", + * RankSVMModel.class.getCanonicalName(), new String[] { "matchedTitle" }, + * "{\"weights\":{\"matchedTitle\":1.0}}"); SolrQuery query = new SolrQuery(); + * query.setQuery("title:w1"); query.add("fl", "*, score"); query.add("rows", + * "4"); query.add("rq", "{!ltr model=Term-matchedTitle reRankDocs=4}"); + * query.set("debugQuery", "on"); String res = restTestHarness.query("/query?" + * + query.toString()); System.out.println(res); assertJQ("/query?" 
+ + * query.toString(), "/response/numFound/==4"); aftertest(); } + */ + +} \ No newline at end of file diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestRankingFeature.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestRankingFeature.java new file mode 100644 index 000000000000..79f0630b75a4 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestRankingFeature.java @@ -0,0 +1,80 @@ +package org.apache.solr.ltr.feature.impl; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.ltr.ranking.RankSVMModel; +import org.junit.Test; + +public class TestRankingFeature extends TestQueryFeature { + @Test + public void testRankingSolrFeature() throws Exception { + // before(); + loadFeature("powpularityS", SolrFeature.class.getCanonicalName(), + "{\"q\":\"{!func}pow(popularity,2)\"}"); + loadFeature("unpopularityS", SolrFeature.class.getCanonicalName(), + "{\"q\":\"{!func}div(1,popularity)\"}"); + + loadModel("powpularityS-model", RankSVMModel.class.getCanonicalName(), + new String[] {"powpularityS"}, "{\"weights\":{\"powpularityS\":1.0}}"); + loadModel("unpopularityS-model", RankSVMModel.class.getCanonicalName(), + new String[] {"unpopularityS"}, "{\"weights\":{\"unpopularityS\":1.0}}"); + + SolrQuery query = new SolrQuery(); + query.setQuery("title:w1"); + query.add("fl", "*, score"); + query.add("rows", "4"); + + assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='8'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='6'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='7'"); + // Normal term match + + query.add("rq", "{!ltr model=powpularityS-model reRankDocs=4}"); + query.set("debugQuery", "on"); + String res = restTestHarness.query("/query" + query.toQueryString()); + System.out.println(res); + assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='8'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==64.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='7'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/score==49.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='6'"); + 
assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/score==36.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='1'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/score==1.0"); + + query.remove("rq"); + query.add("rq", "{!ltr model=unpopularityS-model reRankDocs=4}"); + + query.set("debugQuery", "on"); + res = restTestHarness.query("/query" + query.toQueryString()); + System.out.println(res); + assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==1.0"); + + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='6'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='7'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='8'"); + // aftertest(); + + } +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestUserTermScoreWithQ.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestUserTermScoreWithQ.java new file mode 100644 index 000000000000..d978a6bcf370 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestUserTermScoreWithQ.java @@ -0,0 +1,45 @@ +package org.apache.solr.ltr.feature.impl; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.ltr.ranking.RankSVMModel; +import org.junit.Test; + +public class TestUserTermScoreWithQ extends TestQueryFeature { + @Test + public void testUserTermScoreWithQ() throws Exception { + // before(); + loadFeature("SomeTermQ", SolrFeature.class.getCanonicalName(), + "{\"q\":\"{!terms f=popularity}88888\"}"); + loadModel("Term-modelQ", RankSVMModel.class.getCanonicalName(), + new String[] {"SomeTermQ"}, "{\"weights\":{\"SomeTermQ\":1.0}}"); + SolrQuery query = new SolrQuery(); + query.setQuery("title:w1"); + query.add("fl", "*, score"); + query.add("rows", "4"); + query.add("rq", "{!ltr model=Term-modelQ reRankDocs=4}"); + query.set("debugQuery", "on"); + String res = restTestHarness.query("/query" + query.toQueryString()); + System.out.println(res); + assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==0.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/score==0.0"); + // aftertest(); + } +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestUserTermScorerQuery.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestUserTermScorerQuery.java new file mode 100644 index 000000000000..4e3a0a7e731d --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestUserTermScorerQuery.java @@ -0,0 +1,45 @@ +package org.apache.solr.ltr.feature.impl; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or 
more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.ltr.ranking.RankSVMModel; +import org.junit.Test; + +public class TestUserTermScorerQuery extends TestQueryFeature { + @Test + public void testUserTermScorerQuery() throws Exception { + // before(); + loadFeature("matchedTitleDFExt", SolrFeature.class.getCanonicalName(), + "{\"q\":\"${user_query}\",\"df\":\"title\"}"); + loadModel("Term-matchedTitleDFExt", RankSVMModel.class.getCanonicalName(), + new String[] {"matchedTitleDFExt"}, + "{\"weights\":{\"matchedTitleDFExt\":1.1}}"); + SolrQuery query = new SolrQuery(); + query.setQuery("title:w1"); + query.add("fl", "*, score"); + query.add("rows", "4"); + query.add("rq", + "{!ltr model=Term-matchedTitleDFExt reRankDocs=4 efi.user_query=w8}"); + String res = restTestHarness.query("/query" + query.toQueryString()); + System.out.println(res); + assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='8'"); + // aftertest(); + } +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestUserTermScorereQDF.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestUserTermScorereQDF.java new file mode 100644 
index 000000000000..a7928d703f38 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestUserTermScorereQDF.java @@ -0,0 +1,46 @@ +package org.apache.solr.ltr.feature.impl; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.ltr.ranking.RankSVMModel; +import org.junit.Test; + +public class TestUserTermScorereQDF extends TestQueryFeature { + @Test + public void testUserTermScorerQWithDF() throws Exception { + // before(); + loadFeature("matchedTitleDF", SolrFeature.class.getCanonicalName(), + "{\"q\":\"w5\",\"df\":\"title\"}"); + loadModel("Term-matchedTitleDF", RankSVMModel.class.getCanonicalName(), + new String[] {"matchedTitleDF"}, + "{\"weights\":{\"matchedTitleDF\":1.0}}"); + SolrQuery query = new SolrQuery(); + query.setQuery("title:w1"); + query.add("fl", "*, score"); + query.add("rows", "2"); + query.add("rq", "{!ltr model=Term-matchedTitleDF reRankDocs=4}"); + query.set("debugQuery", "on"); + String res = restTestHarness.query("/query" + query.toQueryString()); + System.out.println(res); + assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='7'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/score==0.0"); + // aftertest(); + } +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestValueFeature.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestValueFeature.java new file mode 100644 index 000000000000..82283e18e590 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/impl/TestValueFeature.java @@ -0,0 +1,149 @@ +package org.apache.solr.ltr.feature.impl; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.lucene.util.LuceneTestCase.SuppressCodecs; +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.ltr.TestRerankBase; +import org.apache.solr.ltr.ranking.LTRComponent; +import org.apache.solr.ltr.ranking.RankSVMModel; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +@SuppressCodecs({"Lucene3x", "Lucene41", "Lucene40", "Appending"}) +public class TestValueFeature extends TestRerankBase { + + @BeforeClass + public static void before() throws Exception { + setuptest("solrconfig-ltr.xml", "schema-ltr.xml"); + + assertU(adoc("id", "1", "title", "w1")); + assertU(adoc("id", "2", "title", "w2")); + assertU(adoc("id", "3", "title", "w3")); + assertU(adoc("id", "4", "title", "w4")); + assertU(adoc("id", "5", "title", "w5")); + assertU(adoc("id", "6", "title", "w1 w2")); + assertU(adoc("id", "7", "title", "w1 w2 w3 w4 w5")); + assertU(adoc("id", "8", "title", "w1 w1 w1 w2 w2")); + assertU(commit()); + } + + @AfterClass + public static void after() throws Exception { + aftertest(); + } + + @Test(expected = Exception.class) + public void testValueFeature1() throws Exception { + loadFeature("c1", ValueFeature.class.getCanonicalName(), "{}"); + } + + @Test(expected = Exception.class) + public void testValueFeature2() throws Exception { + loadFeature("c2", ValueFeature.class.getCanonicalName(), "{\"value\":\"\"}"); + } + + @Test(expected = Exception.class) + public void testValueFeature3() throws Exception { + loadFeature("c2", ValueFeature.class.getCanonicalName(), + "{\"value\":\" 
\"}"); + } + + @Test + public void testValueFeature4() throws Exception { + loadFeature("c3", ValueFeature.class.getCanonicalName(), "c3", + "{\"value\":2}"); + loadModel("m3", RankSVMModel.class.getCanonicalName(), new String[] {"c3"}, + "c3", "{\"weights\":{\"c3\":1.0}}"); + + SolrQuery query = new SolrQuery(); + query.setQuery("title:w1"); + query.add("fl", "*, score"); + query.add("rows", "4"); + query.add("wt", "json"); + query.add("rq", "{!ltr model=m3 reRankDocs=4}"); + + // String res = restTestHarness.query("/query" + query.toQueryString()); + // System.out.println("\n\n333333\n\n" + res); + + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==2.0"); + } + + @Test + public void testValueFeature5() throws Exception { + loadFeature("c4", ValueFeature.class.getCanonicalName(), "c4", + "{\"value\":\"2\"}"); + loadModel("m4", RankSVMModel.class.getCanonicalName(), new String[] {"c4"}, + "c4", "{\"weights\":{\"c4\":1.0}}"); + + SolrQuery query = new SolrQuery(); + query.setQuery("title:w1"); + query.add("fl", "*, score"); + query.add("rows", "4"); + query.add("wt", "json"); + query.add("rq", "{!ltr model=m4 reRankDocs=4}"); + + // String res = restTestHarness.query("/query" + query.toQueryString()); + // System.out.println("\n\n44444\n\n" + res); + + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==2.0"); + } + + @Test + public void testValueFeature6() throws Exception { + loadFeature("c5", ValueFeature.class.getCanonicalName(), "c5", + "{\"value\":\"${val5}\"}"); + loadModel("m5", RankSVMModel.class.getCanonicalName(), new String[] {"c5"}, + "c5", "{\"weights\":{\"c5\":1.0}}"); + + SolrQuery query = new SolrQuery(); + query.setQuery("title:w1"); + query.add("fl", "*, score,fvonly:[fvonly]"); + query.add("rows", "4"); + query.add("wt", "json"); + query.add(LTRComponent.LTRParams.FV, "true"); + query.add("rq", "{!ltr model=m5 reRankDocs=4}"); + + // String res = restTestHarness.query("/query" + query.toQueryString()); + // 
System.out.println(res); + + // No efi.val5 passed in + assertJQ("/query" + query.toQueryString(), "/responseHeader/status==400"); + } + + @Test + public void testValueFeature7() throws Exception { + loadFeature("c6", ValueFeature.class.getCanonicalName(), "c6", + "{\"value\":\"${val6}\"}"); + loadModel("m6", RankSVMModel.class.getCanonicalName(), new String[] {"c6"}, + "c6", "{\"weights\":{\"c6\":1.0}}"); + + SolrQuery query = new SolrQuery(); + query.setQuery("title:w1"); + query.add("fl", "*, score"); + query.add("rows", "4"); + query.add("wt", "json"); + query.add("rq", "{!ltr model=m6 reRankDocs=4 efi.val6='2'}"); + + // String res = restTestHarness.query("/query" + query.toQueryString()); + // System.out.println(res); + + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==2.0"); + } +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/norm/impl/TestMinMaxNormalizer.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/norm/impl/TestMinMaxNormalizer.java new file mode 100644 index 000000000000..b06c09afaf41 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/norm/impl/TestMinMaxNormalizer.java @@ -0,0 +1,89 @@ +package org.apache.solr.ltr.feature.norm.impl; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import static org.junit.Assert.assertEquals; + +import org.apache.solr.ltr.feature.norm.Normalizer; +import org.apache.solr.ltr.rest.ManagedModelStore; +import org.apache.solr.ltr.util.NamedParams; +import org.apache.solr.ltr.util.NormalizerException; +import org.junit.Test; + +public class TestMinMaxNormalizer { + + @Test(expected = NormalizerException.class) + public void testInvalidMinMaxNoParams() throws NormalizerException { + ManagedModelStore.getNormalizerInstance( + MinMaxNormalizer.class.getCanonicalName(), new NamedParams()); + + } + + @Test(expected = NormalizerException.class) + public void testInvalidMinMaxMissingMax() throws NormalizerException { + + ManagedModelStore.getNormalizerInstance( + MinMaxNormalizer.class.getCanonicalName(), + new NamedParams().add("min", 0f)); + + } + + @Test(expected = NormalizerException.class) + public void testInvalidMinMaxMissingMin() throws NormalizerException { + + ManagedModelStore.getNormalizerInstance( + MinMaxNormalizer.class.getCanonicalName(), + new NamedParams().add("max", 0f)); + + } + + @Test(expected = NormalizerException.class) + public void testInvalidMinMaxMissingInvalidDelta() throws NormalizerException { + ManagedModelStore.getNormalizerInstance( + MinMaxNormalizer.class.getCanonicalName(), + new NamedParams().add("max", 0f).add("min", 10f)); + } + + @Test(expected = NormalizerException.class) + public void testInvalidMinMaxMissingInvalidDelta2() + throws NormalizerException { + + ManagedModelStore.getNormalizerInstance( + "org.apache.solr.ltr.feature.norm.impl.MinMaxNormalizer", + new NamedParams().add("min", 10f).add("max", 10f)); + // min == max + } + + @Test + public void testNormalizer() throws NormalizerException { + Normalizer n = ManagedModelStore.getNormalizerInstance( + MinMaxNormalizer.class.getCanonicalName(), + new NamedParams().add("min", 5f).add("max", 10f)); + + float value 
= 8; + assertEquals((value - 5f) / (10f - 5f), n.normalize(value), 0.0001); + value = 100; + assertEquals((value - 5f) / (10f - 5f), n.normalize(value), 0.0001); + value = 150; + assertEquals((value - 5f) / (10f - 5f), n.normalize(value), 0.0001); + value = -1; + assertEquals((value - 5f) / (10f - 5f), n.normalize(value), 0.0001); + value = 5; + assertEquals((value - 5f) / (10f - 5f), n.normalize(value), 0.0001); + } +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/norm/impl/TestStandardNormalizer.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/norm/impl/TestStandardNormalizer.java new file mode 100644 index 000000000000..2178ff064168 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/norm/impl/TestStandardNormalizer.java @@ -0,0 +1,81 @@ +package org.apache.solr.ltr.feature.norm.impl; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import static org.junit.Assert.assertEquals; + +import org.apache.solr.ltr.feature.norm.Normalizer; +import org.apache.solr.ltr.rest.ManagedModelStore; +import org.apache.solr.ltr.util.NamedParams; +import org.apache.solr.ltr.util.NormalizerException; +import org.junit.Test; + +public class TestStandardNormalizer { + + @Test(expected = NormalizerException.class) + public void testNormalizerNoParams() throws NormalizerException { + ManagedModelStore.getNormalizerInstance( + StandardNormalizer.class.getCanonicalName(), new NamedParams()); + + } + + @Test(expected = NormalizerException.class) + public void testInvalidSTD() throws NormalizerException { + + ManagedModelStore.getNormalizerInstance( + StandardNormalizer.class.getCanonicalName(), + new NamedParams().add("std", 0f)); + + } + + @Test(expected = NormalizerException.class) + public void testInvalidSTD2() throws NormalizerException { + + ManagedModelStore.getNormalizerInstance( + StandardNormalizer.class.getCanonicalName(), + new NamedParams().add("std", -1f)); + + } + + @Test(expected = NormalizerException.class) + public void testInvalidSTD3() throws NormalizerException { + ManagedModelStore.getNormalizerInstance( + StandardNormalizer.class.getCanonicalName(), + new NamedParams().add("avg", 1f).add("std", 0f)); + } + + @Test + public void testNormalizer() throws NormalizerException { + Normalizer identity = ManagedModelStore.getNormalizerInstance( + StandardNormalizer.class.getCanonicalName(), + new NamedParams().add("avg", 0f).add("std", 1f)); + + float value = 8; + assertEquals(value, identity.normalize(value), 0.0001); + value = 150; + assertEquals(value, identity.normalize(value), 0.0001); + Normalizer norm = ManagedModelStore.getNormalizerInstance( + StandardNormalizer.class.getCanonicalName(), + new NamedParams().add("avg", 10f).add("std", 1.5f)); + + for (float v : new float[] {10f, 20f, 25f, 30f, 31f, 40f, 42f, 100f, + 10000000f}) { + assertEquals((v - 10f) / (1.5f), norm.normalize(v), 
0.0001); + } + } +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/ranking/TestLTRQParserExplain.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/ranking/TestLTRQParserExplain.java new file mode 100644 index 000000000000..645c010993f8 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/ranking/TestLTRQParserExplain.java @@ -0,0 +1,142 @@ +package org.apache.solr.ltr.ranking; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import java.lang.invoke.MethodHandles; + +import org.apache.solr.SolrTestCaseJ4.SuppressSSL; +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.ltr.TestRerankBase; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@SuppressSSL +public class TestLTRQParserExplain extends TestRerankBase { + + @SuppressWarnings("unused") + private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + + @BeforeClass + public static void setup() throws Exception { + setuptest(); + loadFeatures("features-store-test-model.json"); + } + + @AfterClass + public static void after() throws Exception { + aftertest(); + } + + @Test + public void checkReranked() throws Exception { + + loadModel("svm", RankSVMModel.class.getCanonicalName(), new String[] { + "constant1", "constant2"}, + "{\"weights\":{\"constant1\":1.5,\"constant2\":3.5}}"); + SolrQuery query = new SolrQuery(); + query.setQuery("title:bloomberg"); + query.setParam("debugQuery", "on"); + query.add("rows", "2"); + query.add("rq", "{!ltr reRankDocs=2 model=svm}"); + query.add("fl", "*,score"); + // query.add("wt","xml"); + // System.out.println(restTestHarness.query("/query" + + // query.toQueryString())); + // query.add("wt","json"); + // assertJQ( + // "/query" + query.toQueryString(), + // "/debug/explain/7=='\n8.5 = svm [ org.apache.solr.ltr.ranking.RankSVMModel ] model applied to features, sum of:\n 1.5 = prod of:\n 1.5 = weight on feature [would be cool to have the name :)]\n 1.0 = ValueFeature [name=constant1 value=1.0]\n 7.0 = prod of:\n 3.5 = weight on feature [would be cool to have the name :)]\n 2.0 = ValueFeature [name=constant2 value=2.0]\n'"); + query.add("wt", "xml"); + System.out.println(restTestHarness.query("/query" + query.toQueryString())); + } + + @Test + public void checkReranked2() throws Exception { + loadModel("svm2", 
RankSVMModel.class.getCanonicalName(), new String[] { + "constant1", "constant2", "pop"}, + "{\"weights\":{\"pop\":1.0,\"constant1\":1.5,\"constant2\":3.5}}"); + + SolrQuery query = new SolrQuery(); + query.setQuery("title:bloomberg"); + query.setParam("debugQuery", "on"); + query.add("rows", "2"); + query.add("rq", "{!ltr reRankDocs=2 model=svm2}"); + query.add("fl", "*,score"); + + assertJQ( + "/query" + query.toQueryString(), + "/debug/explain/9=='\n13.5 = svm2 [ org.apache.solr.ltr.ranking.RankSVMModel ] model applied to features, sum of:\n 1.5 = prod of:\n 1.5 = weight on feature [would be cool to have the name :)]\n 1.0 = ValueFeature [name=constant1 value=1.0]\n 7.0 = prod of:\n 3.5 = weight on feature [would be cool to have the name :)]\n 2.0 = ValueFeature [name=constant2 value=2.0]\n 5.0 = prod of:\n 1.0 = weight on feature [would be cool to have the name :)]\n 5.0 = FieldValueFeature [name=pop fields=[popularity]]\n'"); + query.add("wt", "xml"); + System.out.println(restTestHarness.query("/query" + query.toQueryString())); + } + + @Test + public void checkReranked3() throws Exception { + loadFeatures("features-ranksvm.json"); + loadModels("ranksvm-model.json"); + + SolrQuery query = new SolrQuery(); + query.setQuery("title:bloomberg"); + query.setParam("debugQuery", "on"); + query.add("rows", "4"); + query.add("rq", "{!ltr reRankDocs=4 model=6029760550880411648}"); + query.add("fl", "*,score"); + query.add("wt", "xml"); + + System.out.println(restTestHarness.query("/query" + query.toQueryString())); + query.remove("wt"); + query.add("wt", "json"); + assertJQ( + "/query" + query.toQueryString(), + "/debug/explain/7=='\n3.5116758 = 6029760550880411648 [ org.apache.solr.ltr.ranking.RankSVMModel ] model applied to features, sum of:\n 0.0 = prod of:\n 0.0 = weight on feature [would be cool to have the name :)]\n 1.0 = ValueFeature [name=title value=1.0]\n 0.2 = prod of:\n 0.1 = weight on feature [would be cool to have the name :)]\n 2.0 = ValueFeature 
[name=description value=2.0]\n 0.4 = prod of:\n 0.2 = weight on feature [would be cool to have the name :)]\n 2.0 = ValueFeature [name=keywords value=2.0]\n 0.09 = prod of:\n 0.3 = weight on feature [would be cool to have the name :)]\n 0.3 = normalized using org.apache.solr.ltr.feature.norm.impl.MinMaxNormalizer [params {min=0.0, max=10.0}]\n 3.0 = ValueFeature [name=popularity value=3.0]\n 1.6 = prod of:\n 0.4 = weight on feature [would be cool to have the name :)]\n 4.0 = ValueFeature [name=text value=4.0]\n 0.6156155 = prod of:\n 0.1231231 = weight on feature [would be cool to have the name :)]\n 5.0 = ValueFeature [name=queryIntentPerson value=5.0]\n 0.60606056 = prod of:\n 0.12121211 = weight on feature [would be cool to have the name :)]\n 5.0 = ValueFeature [name=queryIntentCompany value=5.0]\n'}"); + assertJQ( + "/query" + query.toQueryString(), + "/debug/explain/9=='\n3.5116758 = 6029760550880411648 [ org.apache.solr.ltr.ranking.RankSVMModel ] model applied to features, sum of:\n 0.0 = prod of:\n 0.0 = weight on feature [would be cool to have the name :)]\n 1.0 = ValueFeature [name=title value=1.0]\n 0.2 = prod of:\n 0.1 = weight on feature [would be cool to have the name :)]\n 2.0 = ValueFeature [name=description value=2.0]\n 0.4 = prod of:\n 0.2 = weight on feature [would be cool to have the name :)]\n 2.0 = ValueFeature [name=keywords value=2.0]\n 0.09 = prod of:\n 0.3 = weight on feature [would be cool to have the name :)]\n 0.3 = normalized using org.apache.solr.ltr.feature.norm.impl.MinMaxNormalizer [params {min=0.0, max=10.0}]\n 3.0 = ValueFeature [name=popularity value=3.0]\n 1.6 = prod of:\n 0.4 = weight on feature [would be cool to have the name :)]\n 4.0 = ValueFeature [name=text value=4.0]\n 0.6156155 = prod of:\n 0.1231231 = weight on feature [would be cool to have the name :)]\n 5.0 = ValueFeature [name=queryIntentPerson value=5.0]\n 0.60606056 = prod of:\n 0.12121211 = weight on feature [would be cool to have the name :)]\n 5.0 = 
ValueFeature [name=queryIntentCompany value=5.0]\n'}"); + } + + // @Test + // public void checkfq() throws Exception { + // + // System.out.println("after: \n" + restTestHarness.query("/config/managed")); + // + // FunctionQueryFeature fq = new FunctionQueryFeature("log(popularity)"); + // assertJPut(featureEndpoint, gson.toJson(fq), "/responseHeader/status==0"); + // fq = new FunctionQueryFeature("tf_title_bloomberg", + // "tf(title,'bloomberg')"); + // assertJPut(featureEndpoint, gson.toJson(fq), "/responseHeader/status==0"); + // // fq.(new NamedParams().add("fq", "log(popularity)")); + // + // ModelMetadata model = new ModelMetadata("sum3", + // SumModel.class.getCanonicalName(), getFeatures(new String[] { + // "log(popularity)", "tf_title_bloomberg", "t1", "t2" })); + // + // assertJPut(modelEndpoint, gson.toJson(model), "/responseHeader/status==0"); + // + // SolrQuery query = new SolrQuery(); + // query.setQuery("title:bloomberg"); + // query.setParam("debugQuery", "on"); + // query.add("rows", "4"); + // query.add("rq", "{!ltr reRankDocs=4 model=sum3}"); + // query.add("fl", "*,score"); + // System.out.println(restTestHarness.query("/query" + + // query.toQueryString())); + // } +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/ranking/TestLTRQParserPlugin.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/ranking/TestLTRQParserPlugin.java new file mode 100644 index 000000000000..607d6bff74e4 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/ranking/TestLTRQParserPlugin.java @@ -0,0 +1,139 @@ +package org.apache.solr.ltr.ranking; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.lang.invoke.MethodHandles; + +import org.apache.lucene.util.LuceneTestCase.SuppressCodecs; +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.ltr.TestRerankBase; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@SuppressCodecs({"Lucene3x", "Lucene41", "Lucene40", "Appending"}) +public class TestLTRQParserPlugin extends TestRerankBase { + + @SuppressWarnings("unused") + private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + + @BeforeClass + public static void before() throws Exception { + setuptest("solrconfig-ltr.xml", "schema-ltr.xml"); + // store = getModelStore(); + bulkIndex(); + + loadFeatures("features-ranksvm.json"); + loadModels("ranksvm-model.json"); + } + + @AfterClass + public static void after() throws Exception { + aftertest(); + // store.clear(); + } + + @Test + public void ltrModelIdMissingTest() throws Exception { + String solrQuery = "_query_:{!edismax qf='title' mm=100% v='bloomberg' tie=0.1}"; + // SolrQueryRequest req = req("q", solrQuery, "rows", "4", "fl", "*,score", + // "fv", "true", "rq", "{!ltr reRankDocs=100}"); + SolrQuery query = new SolrQuery(); + query.setQuery(solrQuery); + query.add("fl", "*, score"); + query.add("rows", "4"); + query.add("fv", "true"); + query.add("rq", "{!ltr reRankDocs=100}"); + + String res = restTestHarness.query("/query" + query.toQueryString()); + // use a JUnit assertion so the check runs even when JVM assertions are disabled + assertTrue(res.contains("Must provide model in the request")); + + //
h.query("/query", req); + + /* + * String solrQuery = + * "_query_:{!edismax qf='title' mm=100% v='bloomberg' tie=0.1}"; + * SolrQueryRequest req = req("q", solrQuery, "rows", "4", "fl", "*,score", + * "fv", "true", "rq", "{!ltr reRankDocs=100}"); + * + * h.query("/query", req); + */ + } + + @Test + public void ltrModelIdDoesNotExistTest() throws Exception { + String solrQuery = "_query_:{!edismax qf='title' mm=100% v='bloomberg' tie=0.1}"; + SolrQuery query = new SolrQuery(); + query.setQuery(solrQuery); + query.add("fl", "*, score"); + query.add("rows", "4"); + query.add("fv", "true"); + query.add("rq", "{!ltr model=-1 reRankDocs=100}"); + + String res = restTestHarness.query("/query" + query.toQueryString()); + // use a JUnit assertion so the check runs even when JVM assertions are disabled + assertTrue(res.contains("cannot find model")); + /* + * String solrQuery = + * "_query_:{!edismax qf='title' mm=100% v='bloomberg' tie=0.1}"; + * SolrQueryRequest req = req("q", solrQuery, "rows", "4", "fl", "*,score", + * "fv", "true", "rq", "{!ltr model=-1 reRankDocs=100}"); + * + * h.query("/query", req); + */ + } + + @Test + public void ltrMoreResultsThanReRankedTest() throws Exception { + String solrQuery = "_query_:{!edismax qf='title' mm=100% v='bloomberg' tie=0.1}"; + SolrQuery query = new SolrQuery(); + query.setQuery(solrQuery); + query.add("fl", "*, score"); + query.add("rows", "4"); + query.add("fv", "true"); + query.add("rq", "{!ltr model=6029760550880411648 reRankDocs=3}"); + + String res = restTestHarness.query("/query" + query.toQueryString()); + System.out.println(res); + // rows (4) exceeds reRankDocs (3), so the request must be rejected + assertTrue(res.contains("Requesting more documents than being reranked.")); + /* + * String solrQuery = + * "_query_:{!edismax qf='title' mm=100% v='bloomberg' tie=0.1}"; + * SolrQueryRequest req = req("q", solrQuery, "rows", "999999", "fl", + * "*,score", "fv", "true", "rq", + * "{!ltr model=6029760550880411648 reRankDocs=100}"); + * + * h.query("/query", req); + */ + } + + @Test + public void ltrNoResultsTest() throws Exception { + SolrQuery query = new SolrQuery(); +
query.setQuery("title:bloomberg23"); + query.add("fl", "*,[fv]"); + query.add("rows", "3"); + query.add("debugQuery", "on"); + query.add(LTRComponent.LTRParams.FV, "true"); + query.add("rq", "{!ltr reRankDocs=3 model=6029760550880411648}"); + assertJQ("/query" + query.toQueryString(), "/response/numFound/==0"); + // assertJQ("/query?" + query.toString(), "/response/numFound/==0"); + } + +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/ranking/TestModelQuery.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/ranking/TestModelQuery.java new file mode 100644 index 000000000000..e2f3d508ac3e --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/ranking/TestModelQuery.java @@ -0,0 +1,234 @@ +package org.apache.solr.ltr.ranking; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; + +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.FloatDocValuesField; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.index.ReaderUtil; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.BooleanClause.Occur; +import org.apache.lucene.search.BooleanQuery.Builder; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.Weight; +import org.apache.lucene.store.Directory; +import org.apache.lucene.util.LuceneTestCase; +import org.apache.lucene.util.LuceneTestCase.SuppressCodecs; +import org.apache.solr.ltr.feature.impl.ValueFeature; +import org.apache.solr.ltr.feature.norm.Normalizer; +import org.apache.solr.ltr.util.FeatureException; +import org.apache.solr.ltr.util.ModelException; +import org.apache.solr.ltr.util.NamedParams; +import org.junit.Test; + +@SuppressCodecs("Lucene3x") +public class TestModelQuery extends LuceneTestCase { + + private IndexSearcher getSearcher(IndexReader r) { + IndexSearcher searcher = newSearcher(r, false, false); + return searcher; + } + + private static List makeFeatures(int[] featureIds) { + List features = new ArrayList<>(); + for (int i : featureIds) { + ValueFeature f = new ValueFeature(); + try { + f.init("f" + i, new NamedParams().add("value", i), i); + } catch (FeatureException e) { + e.printStackTrace(); + } + features.add(f); + } + return features; + } + + private static List makeNormalizedFeatures(int[] featureIds) { + List features = new ArrayList<>(); + for (int i : featureIds) { + ValueFeature f = new ValueFeature(); + f.name = "f" + i; + 
f.params = new NamedParams().add("value", i); + f.id = i; + f.norm = new Normalizer() { + + @Override + public float normalize(float value) { + return 42.42f; + } + }; + features.add(f); + } + return features; + } + + private static NamedParams makeFeatureWeights(List<Feature> features) { + NamedParams nameParams = new NamedParams(); + HashMap<String,Double> modelWeights = new HashMap<>(); + for (Feature feat : features) { + modelWeights.put(feat.name, 0.1); + } + if (modelWeights.isEmpty()) modelWeights.put("", 0.0); + nameParams.add("weights", modelWeights); + return nameParams; + } + + private ModelQuery.ModelWeight performQuery(TopDocs hits, + IndexSearcher searcher, int docid, ModelQuery model) throws IOException, + ModelException { + List<LeafReaderContext> leafContexts = searcher.getTopReaderContext() + .leaves(); + // find the segment holding docid and convert it to a segment-local doc id + int n = ReaderUtil.subIndex(docid, leafContexts); + final LeafReaderContext context = leafContexts.get(n); + int deBasedDoc = docid - context.docBase; + + Weight weight = searcher.createNormalizedWeight(model, true); + Scorer scorer = weight.scorer(context); + + // rerank using the field final-score + scorer.iterator().advance(deBasedDoc); + float score = scorer.score(); + + // assertEquals(42.0f, score, 0.0001); + // assertTrue(weight instanceof AssertingWeight); + // (AssertingIndexSearcher) + assertTrue(weight instanceof ModelQuery.ModelWeight); + ModelQuery.ModelWeight modelWeight = (ModelQuery.ModelWeight) weight; + return modelWeight; + + } + + @Test + public void testModelQuery() throws IOException, ModelException { + Directory dir = newDirectory(); + RandomIndexWriter w = new RandomIndexWriter(random(), dir); + + Document doc = new Document(); + doc.add(newStringField("id", "0", Field.Store.YES)); + doc.add(newTextField("field", "wizard the the the the the oz", + Field.Store.NO)); + doc.add(new FloatDocValuesField("final-score", 1.0f)); + + w.addDocument(doc); + doc = new Document(); + doc.add(newStringField("id", "1", Field.Store.YES)); + // 1 extra
token, but wizard and oz are close; + doc.add(newTextField("field", "wizard oz the the the the the the", + Field.Store.NO)); + doc.add(new FloatDocValuesField("final-score", 2.0f)); + w.addDocument(doc); + + IndexReader r = w.getReader(); + w.close(); + + // Do ordinary BooleanQuery: + Builder bqBuilder = new Builder(); + bqBuilder.add(new TermQuery(new Term("field", "wizard")), Occur.SHOULD); + bqBuilder.add(new TermQuery(new Term("field", "oz")), Occur.SHOULD); + IndexSearcher searcher = getSearcher(r); + // first run the standard query + TopDocs hits = searcher.search(bqBuilder.build(), 10); + assertEquals(2, hits.totalHits); + assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id")); + assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id")); + + List features = makeFeatures(new int[] {0, 1, 2}); + List allFeatures = makeFeatures(new int[] {0, 1, 2, 3, 4, 5, 6, 7, + 8, 9}); + RankSVMModel meta = new RankSVMModel("test", + RankSVMModel.class.getCanonicalName(), features, "test", allFeatures, + makeFeatureWeights(features)); + + ModelQuery.ModelWeight modelWeight = performQuery(hits, searcher, + hits.scoreDocs[0].doc, new ModelQuery(meta)); + assertEquals(3, modelWeight.modelFeatureValuesNormalized.length); + assertEquals(10, modelWeight.allFeatureValues.length); + + for (int i = 0; i < 3; i++) { + assertEquals((float) i, modelWeight.modelFeatureValuesNormalized[i], + 0.0001); + } + for (int i = 0; i < 10; i++) { + assertEquals((float) i, modelWeight.allFeatureValues[i], 0.0001); + } + + for (int i = 0; i < 10; i++) { + assertEquals("f" + i, modelWeight.allFeatureNames[i]); + + } + + int[] mixPositions = new int[] {8, 2, 4, 9, 0}; + features = makeFeatures(mixPositions); + meta = new RankSVMModel("test", RankSVMModel.class.getCanonicalName(), + features, "test", allFeatures, makeFeatureWeights(features)); + + modelWeight = performQuery(hits, searcher, hits.scoreDocs[0].doc, + new ModelQuery(meta)); + assertEquals(mixPositions.length, + 
modelWeight.modelFeatureValuesNormalized.length); + + for (int i = 0; i < mixPositions.length; i++) { + assertEquals((float) mixPositions[i], + modelWeight.modelFeatureValuesNormalized[i], 0.0001); + } + for (int i = 0; i < 10; i++) { + assertEquals((float) i, modelWeight.allFeatureValues[i], 0.0001); + } + + int[] noPositions = new int[] {}; + features = makeFeatures(noPositions); + meta = new RankSVMModel("test", RankSVMModel.class.getCanonicalName(), + features, "test", allFeatures, makeFeatureWeights(features)); + + modelWeight = performQuery(hits, searcher, hits.scoreDocs[0].doc, + new ModelQuery(meta)); + assertEquals(0, modelWeight.modelFeatureValuesNormalized.length); + + // test normalizers + features = makeNormalizedFeatures(mixPositions); + RankSVMModel normMeta = new RankSVMModel("test", + RankSVMModel.class.getCanonicalName(), features, "test", allFeatures, + makeFeatureWeights(features)); + + modelWeight = performQuery(hits, searcher, hits.scoreDocs[0].doc, + new ModelQuery(normMeta)); + assertEquals(mixPositions.length, + modelWeight.modelFeatureValuesNormalized.length); + for (int i = 0; i < mixPositions.length; i++) { + assertEquals(42.42f, modelWeight.modelFeatureValuesNormalized[i], 0.0001); + } + for (int i = 0; i < 10; i++) { + assertEquals((float) i, modelWeight.allFeatureValues[i], 0.0001); + } + r.close(); + dir.close(); + + } + +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/ranking/TestReRankingPipeline.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/ranking/TestReRankingPipeline.java new file mode 100644 index 000000000000..cb90d1fc0886 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/ranking/TestReRankingPipeline.java @@ -0,0 +1,286 @@ +package org.apache.solr.ltr.ranking; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.io.IOException; +import java.lang.invoke.MethodHandles; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.FloatDocValuesField; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.BooleanClause.Occur; +import org.apache.lucene.search.BooleanQuery.Builder; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.Scorer.ChildScorer; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.store.Directory; +import org.apache.lucene.util.LuceneTestCase; +import org.apache.lucene.util.LuceneTestCase.SuppressCodecs; +import org.apache.solr.ltr.feature.ModelMetadata; +import org.apache.solr.ltr.feature.impl.FieldValueFeature; +import org.apache.solr.ltr.ranking.ModelQuery.ModelWeight; +import org.apache.solr.ltr.ranking.ModelQuery.ModelWeight.ModelScorer; +import org.apache.solr.ltr.util.NamedParams; +import org.junit.Ignore; +import org.junit.Test; +import org.slf4j.Logger; +import 
org.slf4j.LoggerFactory; + +@SuppressCodecs("Lucene3x") +public class TestReRankingPipeline extends LuceneTestCase { + + private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + + private IndexSearcher getSearcher(IndexReader r) { + IndexSearcher searcher = newSearcher(r); + + return searcher; + } + + private static List makeFieldValueFeatures(int[] featureIds, + String field) { + List features = new ArrayList<>(); + for (int i : featureIds) { + FieldValueFeature f = new FieldValueFeature(); + f.name = "f" + i; + f.params = new NamedParams().add("field", field); + features.add(f); + } + return features; + } + + private class MockModel extends ModelMetadata { + + public MockModel(String name, String type, List features, + String featureStoreName, Collection allFeatures, + NamedParams params) { + super(name, type, features, featureStoreName, allFeatures, params); + } + + @Override + public float score(float[] modelFeatureValuesNormalized) { + return modelFeatureValuesNormalized[2]; + } + + @Override + public Explanation explain(LeafReaderContext context, int doc, + float finalScore, List featureExplanations) { + return null; + } + + } + + @Ignore + @Test + public void testRescorer() throws IOException { + Directory dir = newDirectory(); + RandomIndexWriter w = new RandomIndexWriter(random(), dir); + + Document doc = new Document(); + doc.add(newStringField("id", "0", Field.Store.YES)); + doc.add(newTextField("field", "wizard the the the the the oz", + Field.Store.NO)); + doc.add(new FloatDocValuesField("final-score", 1.0f)); + + w.addDocument(doc); + doc = new Document(); + doc.add(newStringField("id", "1", Field.Store.YES)); + // 1 extra token, but wizard and oz are close; + doc.add(newTextField("field", "wizard oz the the the the the the", + Field.Store.NO)); + doc.add(new FloatDocValuesField("final-score", 2.0f)); + w.addDocument(doc); + + IndexReader r = w.getReader(); + w.close(); + + // Do ordinary BooleanQuery: + 
Builder bqBuilder = new Builder(); + bqBuilder.add(new TermQuery(new Term("field", "wizard")), Occur.SHOULD); + bqBuilder.add(new TermQuery(new Term("field", "oz")), Occur.SHOULD); + IndexSearcher searcher = getSearcher(r); + // first run the standard query + TopDocs hits = searcher.search(bqBuilder.build(), 10); + assertEquals(2, hits.totalHits); + assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id")); + assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id")); + + List features = makeFieldValueFeatures(new int[] {0, 1, 2}, + "final-score"); + List allFeatures = makeFieldValueFeatures(new int[] {0, 1, 2, 3, + 4, 5, 6, 7, 8, 9}, "final-score"); + RankSVMModel meta = new RankSVMModel("test", + MockModel.class.getCanonicalName(), features, "test", allFeatures, null); + + LTRRescorer rescorer = new LTRRescorer(new ModelQuery(meta)); + hits = rescorer.rescore(searcher, hits, 2); + + // rerank using the field final-score + assertEquals("1", searcher.doc(hits.scoreDocs[0].doc).get("id")); + assertEquals("0", searcher.doc(hits.scoreDocs[1].doc).get("id")); + + r.close(); + dir.close(); + + } + + @Ignore + @Test + public void testDifferentTopN() throws IOException { + Directory dir = newDirectory(); + RandomIndexWriter w = new RandomIndexWriter(random(), dir); + + Document doc = new Document(); + doc.add(newStringField("id", "0", Field.Store.YES)); + doc.add(newTextField("field", "wizard oz oz oz oz oz", Field.Store.NO)); + doc.add(new FloatDocValuesField("final-score", 1.0f)); + w.addDocument(doc); + + doc = new Document(); + doc.add(newStringField("id", "1", Field.Store.YES)); + doc.add(newTextField("field", "wizard oz oz oz oz the", Field.Store.NO)); + doc.add(new FloatDocValuesField("final-score", 2.0f)); + w.addDocument(doc); + doc = new Document(); + doc.add(newStringField("id", "2", Field.Store.YES)); + doc.add(newTextField("field", "wizard oz oz oz the the ", Field.Store.NO)); + doc.add(new FloatDocValuesField("final-score", 3.0f)); + 
w.addDocument(doc); + doc = new Document(); + doc.add(newStringField("id", "3", Field.Store.YES)); + doc.add(newTextField("field", "wizard oz oz the the the the ", + Field.Store.NO)); + doc.add(new FloatDocValuesField("final-score", 4.0f)); + w.addDocument(doc); + doc = new Document(); + doc.add(newStringField("id", "4", Field.Store.YES)); + doc.add(newTextField("field", "wizard oz the the the the the the", + Field.Store.NO)); + doc.add(new FloatDocValuesField("final-score", 5.0f)); + w.addDocument(doc); + + IndexReader r = w.getReader(); + w.close(); + + // Do ordinary BooleanQuery: + Builder bqBuilder = new Builder(); + bqBuilder.add(new TermQuery(new Term("field", "wizard")), Occur.SHOULD); + bqBuilder.add(new TermQuery(new Term("field", "oz")), Occur.SHOULD); + IndexSearcher searcher = getSearcher(r); + + // first run the standard query + TopDocs hits = searcher.search(bqBuilder.build(), 10); + assertEquals(5, hits.totalHits); + for (int i = 0; i < 5; i++) { + System.out.print(hits.scoreDocs[i].doc + " -> "); + System.out.println(searcher.doc(hits.scoreDocs[i].doc).get("id")); + } + + assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id")); + assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id")); + assertEquals("2", searcher.doc(hits.scoreDocs[2].doc).get("id")); + assertEquals("3", searcher.doc(hits.scoreDocs[3].doc).get("id")); + assertEquals("4", searcher.doc(hits.scoreDocs[4].doc).get("id")); + + List features = makeFieldValueFeatures(new int[] {0, 1, 2}, + "final-score"); + List allFeatures = makeFieldValueFeatures(new int[] {0, 1, 2, 3, + 4, 5, 6, 7, 8, 9}, "final-score"); + RankSVMModel meta = new RankSVMModel("test", + MockModel.class.getCanonicalName(), features, "test", allFeatures, null); + + LTRRescorer rescorer = new LTRRescorer(new ModelQuery(meta)); + + // rerank @ 0 should not change the order + hits = rescorer.rescore(searcher, hits, 0); + assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id")); + assertEquals("1", 
searcher.doc(hits.scoreDocs[1].doc).get("id")); + assertEquals("2", searcher.doc(hits.scoreDocs[2].doc).get("id")); + assertEquals("3", searcher.doc(hits.scoreDocs[3].doc).get("id")); + assertEquals("4", searcher.doc(hits.scoreDocs[4].doc).get("id")); + + // test rerank with different topN cuts + + for (int topN = 1; topN <= 5; topN++) { + logger.info("rerank {} documents ", topN); + hits = searcher.search(bqBuilder.build(), 10); + // meta = new MockModel(); + // rescorer = new LTRRescorer(new ModelQuery(meta)); + ScoreDoc[] slice = new ScoreDoc[topN]; + System.arraycopy(hits.scoreDocs, 0, slice, 0, topN); + hits = new TopDocs(hits.totalHits, slice, hits.getMaxScore()); + hits = rescorer.rescore(searcher, hits, topN); + for (int i = topN - 1, j = 0; i >= 0; i--, j++) { + logger.info("doc {} in pos {}", searcher.doc(hits.scoreDocs[j].doc) + .get("id"), j); + + assertEquals(i, + Integer.parseInt(searcher.doc(hits.scoreDocs[j].doc).get("id"))); + assertEquals(i + 1, hits.scoreDocs[j].score, 0.00001); + + } + } + + r.close(); + dir.close(); + + } + + @Test + public void testDocParam() throws Exception { + NamedParams test = new NamedParams(); + test.add("fake", 2); + List features = makeFieldValueFeatures(new int[] {0}, + "final-score"); + List allFeatures = makeFieldValueFeatures(new int[] {0}, + "final-score"); + MockModel meta = new MockModel("test", MockModel.class.getCanonicalName(), + features, "test", allFeatures, null); + ModelQuery query = new ModelQuery(meta); + ModelWeight wgt = query.createWeight(null, true); + ModelScorer modelScr = wgt.scorer(null); + modelScr.setDocInfoParam("ORIGINAL_SCORE", 1); + for (ChildScorer feat : modelScr.getChildren()) { + assert (((FeatureScorer) feat.child).hasDocParam("ORIGINAL_SCORE")); + } + + features = makeFieldValueFeatures(new int[] {0, 1, 2}, "final-score"); + allFeatures = makeFieldValueFeatures(new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8, + 9}, "final-score"); + meta = new MockModel("test", 
MockModel.class.getCanonicalName(), features, + "test", allFeatures, null); + query = new ModelQuery(meta); + wgt = query.createWeight(null, true); + modelScr = wgt.scorer(null); + modelScr.setDocInfoParam("ORIGINAL_SCORE", 1); + for (ChildScorer feat : modelScr.getChildren()) { + assert (((FeatureScorer) feat.child).hasDocParam("ORIGINAL_SCORE")); + } + } + +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/rest/TestModelManager.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/rest/TestModelManager.java new file mode 100644 index 000000000000..3fd4bbc55b05 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/rest/TestModelManager.java @@ -0,0 +1,170 @@ +package org.apache.solr.ltr.rest; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import org.apache.solr.SolrTestCaseJ4.SuppressSSL; +import org.apache.solr.common.util.NamedList; +import org.apache.solr.core.SolrResourceLoader; +import org.apache.solr.ltr.TestRerankBase; +import org.apache.solr.ltr.ranking.LTRComponent; +import org.apache.solr.ltr.ranking.LTRComponent.LTRParams; +import org.apache.solr.rest.ManagedResource; +import org.apache.solr.rest.ManagedResourceStorage; +import org.apache.solr.rest.RestManager; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +@SuppressSSL +public class TestModelManager extends TestRerankBase { + + @BeforeClass + public static void init() throws Exception { + setuptest(); + } + + @Before + public void restart() throws Exception { + restTestHarness.delete(LTRParams.MSTORE_END_POINT + "/*"); + restTestHarness.delete(LTRParams.FSTORE_END_POINT + "/*"); + + } + + @Test + public void test() throws Exception { + SolrResourceLoader loader = new SolrResourceLoader(tmpSolrHome.toPath()); + + RestManager.Registry registry = loader.getManagedResourceRegistry(); + assertNotNull( + "Expected a non-null RestManager.Registry from the SolrResourceLoader!", + registry); + + String resourceId = "/schema/fstore1"; + registry.registerManagedResource(resourceId, ManagedFeatureStore.class, + new LTRComponent()); + + String resourceId2 = "/schema/mstore1"; + registry.registerManagedResource(resourceId2, ManagedModelStore.class, + new LTRComponent()); + + NamedList initArgs = new NamedList<>(); + + RestManager restManager = new RestManager(); + restManager.init(loader, initArgs, + new ManagedResourceStorage.InMemoryStorageIO()); + + ManagedResource res = restManager.getManagedResource(resourceId); + assertTrue(res instanceof ManagedFeatureStore); + assertEquals(res.getResourceId(), resourceId); + + } + + @Test + public void testRestManagerEndpoints() throws Exception { + // relies on these ManagedResources being activated in the + // schema-rest.xml used by this test + 
assertJQ("/schema/managed", "/responseHeader/status==0"); + System.out.println(restTestHarness.query("/schema/managed")); + + // Add features + String feature = "{\"name\": \"test1\", \"type\": \"org.apache.solr.ltr.feature.impl.ValueFeature\", \"params\": {\"value\": 1} }"; + assertJPut(TestRerankBase.FEATURE_ENDPOINT, feature, + "/responseHeader/status==0"); + + feature = "{\"name\": \"test2\", \"type\": \"org.apache.solr.ltr.feature.impl.ValueFeature\", \"params\": {\"value\": 1} }"; + assertJPut(TestRerankBase.FEATURE_ENDPOINT, feature, + "/responseHeader/status==0"); + + feature = "{\"name\": \"test3\", \"type\": \"org.apache.solr.ltr.feature.impl.ValueFeature\", \"params\": {\"value\": 1} }"; + assertJPut(TestRerankBase.FEATURE_ENDPOINT, feature, + "/responseHeader/status==0"); + + feature = "{\"name\": \"test33\", \"store\": \"TEST\", \"type\": \"org.apache.solr.ltr.feature.impl.ValueFeature\", \"params\": {\"value\": 1} }"; + assertJPut(TestRerankBase.FEATURE_ENDPOINT, feature, + "/responseHeader/status==0"); + + String multipleFeatures = "[{\"name\": \"test4\", \"type\": \"org.apache.solr.ltr.feature.impl.ValueFeature\", \"params\": {\"value\": 1} }" + + ",{\"name\": \"test5\", \"type\": \"org.apache.solr.ltr.feature.impl.ValueFeature\", \"params\": {\"value\": 1} } ]"; + assertJPut(TestRerankBase.FEATURE_ENDPOINT, multipleFeatures, + "/responseHeader/status==0"); + + // Add bad feature (wrong params): FieldValueFeature is given "value" instead of a field + String badfeature = "{\"name\": \"fvalue\", \"type\": \"org.apache.solr.ltr.feature.impl.FieldValueFeature\", \"params\": {\"value\": 1} }"; + assertJPut(TestRerankBase.FEATURE_ENDPOINT, badfeature, + "/responseHeader/status==400"); + + // Add models + String model = "{ \"name\":\"testmodel1\", \"type\":\"org.apache.solr.ltr.ranking.RankSVMModel\", \"features\":[] }"; + // fails since it does not have features + assertJPut(TestRerankBase.MODEL_ENDPOINT, model, +
"/responseHeader/status==400"); + // fails since it does not have weights + model = "{ \"name\":\"testmodel2\", \"type\":\"org.apache.solr.ltr.ranking.RankSVMModel\", \"features\":[{\"name\":\"test1\"}, {\"name\":\"test2\"}] }"; + assertJPut(TestRerankBase.MODEL_ENDPOINT, model, + "/responseHeader/status==400"); + // success + model = "{ \"name\":\"testmodel3\", \"type\":\"org.apache.solr.ltr.ranking.RankSVMModel\", \"features\":[{\"name\":\"test1\"}, {\"name\":\"test2\"}],\"params\":{\"weights\":{\"test1\":1.5,\"test2\":2.0}}}"; + assertJPut(TestRerankBase.MODEL_ENDPOINT, model, + "/responseHeader/status==0"); + // success + String multipleModels = "[{ \"name\":\"testmodel4\", \"type\":\"org.apache.solr.ltr.ranking.RankSVMModel\", \"features\":[{\"name\":\"test1\"}, {\"name\":\"test2\"}],\"params\":{\"weights\":{\"test1\":1.5,\"test2\":2.0}} }\n" + + ",{ \"name\":\"testmodel5\", \"type\":\"org.apache.solr.ltr.ranking.RankSVMModel\", \"features\":[{\"name\":\"test1\"}, {\"name\":\"test2\"}],\"params\":{\"weights\":{\"test1\":1.5,\"test2\":2.0}} } ]"; + assertJPut(TestRerankBase.MODEL_ENDPOINT, multipleModels, + "/responseHeader/status==0"); + String qryResult = JQ(LTRParams.MSTORE_END_POINT); + + assert (qryResult.contains("\"name\":\"testmodel3\"") + && qryResult.contains("\"name\":\"testmodel4\"") && qryResult + .contains("\"name\":\"testmodel5\"")); + /* + * assertJQ(LTRParams.MSTORE_END_POINT, "/models/[0]/name=='testmodel3'"); + * assertJQ(LTRParams.MSTORE_END_POINT, "/models/[1]/name=='testmodel4'"); + * assertJQ(LTRParams.MSTORE_END_POINT, "/models/[2]/name=='testmodel5'"); + */ + assertJQ(LTRParams.FSTORE_END_POINT, "/featureStores==['TEST','_DEFAULT_']"); + assertJQ(LTRParams.FSTORE_END_POINT + "/_DEFAULT_", + "/features/[0]/name=='test1'"); + assertJQ(LTRParams.FSTORE_END_POINT + "/TEST", + "/features/[0]/name=='test33'"); + } + + @Test + public void testEndpointsFromFile() throws Exception { + loadFeatures("features-ranksvm.json"); + 
loadModels("ranksvm-model.json"); + + assertJQ(LTRParams.MSTORE_END_POINT, + "/models/[0]/name=='6029760550880411648'"); + assertJQ(LTRParams.FSTORE_END_POINT + "/_DEFAULT_", + "/features/[1]/name=='description'"); + } + + @Test + public void testLoadInvalidFeature() throws Exception { + // relies on these ManagedResources being activated in the + // schema-rest.xml used by this test + assertJQ("/schema/managed", "/responseHeader/status==0"); + String newEndpoint = LTRParams.FSTORE_END_POINT; + String feature = "{\"name\": \"^&test1\", \"type\": \"org.apache.solr.ltr.feature.impl.ValueFeature\", \"params\": {\"value\": 1} }"; + assertJPut(newEndpoint, feature, "/responseHeader/status==400"); + + } + +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/rest/TestModelManagerPersistence.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/rest/TestModelManagerPersistence.java new file mode 100644 index 000000000000..6a528783bb20 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/rest/TestModelManagerPersistence.java @@ -0,0 +1,86 @@ +package org.apache.solr.ltr.rest; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import org.apache.commons.io.FileUtils; +import org.apache.solr.SolrTestCaseJ4.SuppressSSL; +import org.apache.solr.ltr.TestRerankBase; +import org.apache.solr.ltr.feature.impl.ValueFeature; +import org.apache.solr.ltr.ranking.LTRComponent.LTRParams; +import org.apache.solr.ltr.ranking.RankSVMModel; +import org.junit.Before; +import org.junit.Test; + +@SuppressSSL +public class TestModelManagerPersistence extends TestRerankBase { + + @Before + public void init() throws Exception { + setupPersistenttest(); + } + + // executed first + @Test + public void testFeaturePersistence() throws Exception { + + loadFeature("feature", ValueFeature.class.getCanonicalName(), "test", + "{\"value\":2}"); + System.out.println(restTestHarness.query(LTRParams.FSTORE_END_POINT + + "/test")); + assertJQ(LTRParams.FSTORE_END_POINT + "/test", + "/features/[0]/name=='feature'"); + restTestHarness.reload(); + assertJQ(LTRParams.FSTORE_END_POINT + "/test", + "/features/[0]/name=='feature'"); + loadFeature("feature1", ValueFeature.class.getCanonicalName(), "test1", + "{\"value\":2}"); + loadFeature("feature2", ValueFeature.class.getCanonicalName(), "test", + "{\"value\":2}"); + loadFeature("feature3", ValueFeature.class.getCanonicalName(), "test2", + "{\"value\":2}"); + assertJQ(LTRParams.FSTORE_END_POINT + "/test", + "/features/[0]/name=='feature'"); + assertJQ(LTRParams.FSTORE_END_POINT + "/test", + "/features/[1]/name=='feature2'"); + assertJQ(LTRParams.FSTORE_END_POINT + "/test1", + "/features/[0]/name=='feature1'"); + assertJQ(LTRParams.FSTORE_END_POINT + "/test2", + "/features/[0]/name=='feature3'"); + restTestHarness.reload(); + assertJQ(LTRParams.FSTORE_END_POINT + "/test", + "/features/[0]/name=='feature'"); + assertJQ(LTRParams.FSTORE_END_POINT + "/test", + "/features/[1]/name=='feature2'"); + assertJQ(LTRParams.FSTORE_END_POINT + "/test1", + "/features/[0]/name=='feature1'"); + assertJQ(LTRParams.FSTORE_END_POINT + "/test2", + "/features/[0]/name=='feature3'"); + 
loadModel("test-model", RankSVMModel.class.getCanonicalName(), + new String[] {"feature"}, "test", "{\"weights\":{\"feature\":1.0}}"); + String fstorecontent = FileUtils.readFileToString(fstorefile,"UTF-8"); + String mstorecontent = FileUtils.readFileToString(mstorefile,"UTF-8"); + + System.out.println("fstore:\n"); + System.out.println(fstorecontent); + + System.out.println("mstore:\n"); + System.out.println(mstorecontent); + + } + +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/util/TestMacroExpander.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/util/TestMacroExpander.java new file mode 100644 index 000000000000..ffbd6dc2178c --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/util/TestMacroExpander.java @@ -0,0 +1,58 @@ +package org.apache.solr.ltr.util; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import static org.junit.Assert.assertEquals; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.Test; + +public class TestMacroExpander { + + @Test + public void testEmptyExpander() { + Map<String,String> efi = new HashMap<>(); + MacroExpander macroExpander = new MacroExpander(efi); + + assertEquals("", macroExpander.expand("")); + assertEquals("foo", macroExpander.expand("foo")); + assertEquals("$foo", macroExpander.expand("$foo")); + assertEquals("${foo}", macroExpander.expand("${foo}")); + assertEquals("{foo}", macroExpander.expand("{foo}")); + assertEquals("${foo}", MacroExpander.expand("${foo}", efi)); + } + + @Test + public void testExpander() { + Map<String,String> efi = new HashMap<>(); + efi.put("foo", "bar"); + efi.put("baz", "bat"); + MacroExpander macroExpander = new MacroExpander(efi); + + assertEquals("", macroExpander.expand("")); + assertEquals("foo", macroExpander.expand("foo")); + assertEquals("$foo", macroExpander.expand("$foo")); + assertEquals("bar", macroExpander.expand("${foo}")); + assertEquals("{foo}", macroExpander.expand("{foo}")); + assertEquals("bar", MacroExpander.expand("${foo}", efi)); + assertEquals("foo bar baz bat", + macroExpander.expand("foo ${foo} baz ${baz}")); + } +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/util/TestNameValidator.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/util/TestNameValidator.java new file mode 100644 index 000000000000..c54bb30b1d32 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/util/TestNameValidator.java @@ -0,0 +1,48 @@ +package org.apache.solr.ltr.util; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import org.junit.Test; + +public class TestNameValidator { + + @Test + public void testValidator() { + assertTrue(NameValidator.check("test")); + assertTrue(NameValidator.check("constant")); + assertTrue(NameValidator.check("test_test")); + assertTrue(NameValidator.check("TEst")); + assertTrue(NameValidator.check("TEST")); + assertTrue(NameValidator.check("328195082960784")); + assertFalse(NameValidator.check(" ")); + assertFalse(NameValidator.check("")); + assertFalse(NameValidator.check("test?")); + assertFalse(NameValidator.check("??????")); + assertFalse(NameValidator.check("_____-----")); + assertFalse(NameValidator.check("12345,67890.31")); + assertFalse(NameValidator.check("aasdasdadasdzASADADSAZ01239()[]|_-")); + assertFalse(NameValidator.check(null)); + assertTrue(NameValidator.check("a")); + assertTrue(NameValidator.check("test()")); + assertTrue(NameValidator.check("test________123")); + + } +}