From e183affdd060f07b58f582c61acef22a0371a4b9 Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Wed, 8 Jul 2015 20:59:07 +0300 Subject: [PATCH 1/2] Feature: support for per user api keys --- migrations/0009_add_api_key_to_user.py | 27 ++++++++ redash/authentication.py | 92 ++++++++++++-------------- redash/controllers.py | 30 +++++---- redash/models.py | 12 ++++ redash/utils.py | 9 +++ redash/wsgi.py | 6 +- requirements.txt | 4 +- tests/test_authentication.py | 64 +++++++++++++++--- tests/test_controllers.py | 52 --------------- 9 files changed, 167 insertions(+), 129 deletions(-) create mode 100644 migrations/0009_add_api_key_to_user.py diff --git a/migrations/0009_add_api_key_to_user.py b/migrations/0009_add_api_key_to_user.py new file mode 100644 index 0000000000..4df4cd1de3 --- /dev/null +++ b/migrations/0009_add_api_key_to_user.py @@ -0,0 +1,27 @@ +from playhouse.migrate import PostgresqlMigrator, migrate + +from redash.models import db +from redash import models + +if __name__ == '__main__': + db.connect_db() + migrator = PostgresqlMigrator(db.database) + + with db.database.transaction(): + column = models.User.api_key + column.null = True + migrate( + migrator.add_column('users', 'api_key', models.User.api_key), + ) + + for user in models.User.select(): + user.save() + + migrate( + migrator.add_not_null('users', 'api_key') + ) + + db.close_db(None) + + + diff --git a/redash/authentication.py b/redash/authentication.py index 6eebff15bf..1054eea414 100644 --- a/redash/authentication.py +++ b/redash/authentication.py @@ -1,11 +1,9 @@ -import functools import hashlib import hmac import time import logging -from flask import request, make_response, redirect, url_for -from flask.ext.login import LoginManager, login_user, current_user, logout_user +from flask.ext.login import LoginManager from redash import models, settings, google_oauth, saml_auth @@ -23,78 +21,72 @@ def sign(key, path, expires): return h.hexdigest() -class Authentication(object): - def verify_authentication(self): - return False - - def required(self, fn): - @functools.wraps(fn) - def decorated(*args, **kwargs): - if current_user.is_authenticated() or self.verify_authentication(): - return fn(*args, **kwargs) - - return make_response(redirect(url_for("login", next=request.url))) - - return decorated - - -class ApiKeyAuthentication(Authentication): - def verify_authentication(self): - api_key = request.args.get('api_key') - query_id = request.view_args.get('query_id', None) - - if query_id and api_key: - query = models.Query.get(models.Query.id == query_id) +@login_manager.user_loader +def load_user(user_id): + return models.User.get_by_id(user_id) - if query.api_key and api_key == query.api_key: - login_user(models.ApiUser(query.api_key), remember=False) - return True - return False +def hmac_load_user_from_request(request): + signature = request.args.get('signature') + expires = float(request.args.get('expires') or 0) + query_id = request.view_args.get('query_id', None) + user_id = request.args.get('user_id', None) + # TODO: 3600 should be a setting + if signature and time.time() < expires <= time.time() + 3600: + if user_id: + user = models.User.get_by_id(user_id) + calculated_signature = sign(user.api_key, request.path, expires) -class HMACAuthentication(Authentication): - def verify_authentication(self): - signature = request.args.get('signature') - expires = float(request.args.get('expires') or 0) - query_id = request.view_args.get('query_id', None) + if user.api_key and signature == calculated_signature: + return user - # TODO: 3600 should be a setting - if signature and query_id and time.time() < expires <= time.time() + 3600: + if query_id: query = models.Query.get(models.Query.id == query_id) calculated_signature = sign(query.api_key, request.path, expires) if query.api_key and signature == calculated_signature: - login_user(models.ApiUser(query.api_key), remember=False) - return True - - return False + return models.ApiUser(query.api_key) + return None -@login_manager.user_loader -def load_user(user_id): - # If the user was previously logged in as api user, the user_id will be the api key and will raise an exception as - # it can't be casted to int. - if isinstance(user_id, basestring) and not user_id.isdigit(): +def get_user_from_api_key(api_key, query_id): + if not api_key: return None - return models.User.select().where(models.User.id == user_id).first() + user = None + try: + user = models.User.get_by_api_key(api_key) + except models.User.DoesNotExist: + if query_id: + query = models.Query.get_by_id(query_id) + if query and query.api_key == api_key: + user = models.ApiUser(api_key) + + return user + +def api_key_load_user_from_request(request): + api_key = request.args.get('api_key', None) + query_id = request.view_args.get('query_id', None) + + user = get_user_from_api_key(api_key, query_id) + return user def setup_authentication(app): login_manager.init_app(app) login_manager.anonymous_user = models.AnonymousUser + login_manager.login_view = 'login' app.secret_key = settings.COOKIE_SECRET app.register_blueprint(google_oauth.blueprint) app.register_blueprint(saml_auth.blueprint) if settings.AUTH_TYPE == 'hmac': - auth = HMACAuthentication() + login_manager.request_loader(hmac_load_user_from_request) elif settings.AUTH_TYPE == 'api_key': - auth = ApiKeyAuthentication() + login_manager.request_loader(api_key_load_user_from_request) else: logger.warning("Unknown authentication type ({}). Using default (HMAC).".format(settings.AUTH_TYPE)) - auth = HMACAuthentication() + login_manager.request_loader(hmac_load_user_from_request) - return auth diff --git a/redash/controllers.py b/redash/controllers.py index 69c1269741..6ba4737cb5 100644 --- a/redash/controllers.py +++ b/redash/controllers.py @@ -14,11 +14,11 @@ from flask import render_template, send_from_directory, make_response, request, jsonify, redirect, \ session, url_for, current_app from flask.ext.restful import Resource, abort -from flask_login import current_user, login_user, logout_user +from flask_login import current_user, login_user, logout_user, login_required import sqlparse -from redash import redis_connection, statsd_client, models, settings, utils, __version__ -from redash.wsgi import app, auth, api +from redash import statsd_client, models, settings, utils +from redash.wsgi import app, api from redash.tasks import QueryTask, record_event from redash.cache import headers as cache_headers from redash.permissions import require_permission @@ -38,7 +38,7 @@ def ping(): @app.route('/queries//') @app.route('/personal') @app.route('/') -@auth.required +@login_required def index(**kwargs): email_md5 = hashlib.md5(current_user.email.lower()).hexdigest() gravatar_url = "https://www.gravatar.com/avatar/%s?s=40" % email_md5 @@ -72,13 +72,15 @@ def login(): else: return redirect(url_for("google_oauth.authorize", next=request.args.get('next'))) - if request.method == 'POST': - user = models.User.select().where(models.User.email == request.form['username']).first() - if user and user.verify_password(request.form['password']): - remember = ('remember' in request.form) - login_user(user, remember=remember) - return redirect(request.args.get('next') or '/') + try: + user = models.User.get_by_email(request.form['username']) + if user and user.verify_password(request.form['password']): + remember = ('remember' in request.form) + login_user(user, remember=remember) + return redirect(request.args.get('next') or '/') + except models.User.DoesNotExist: + pass return render_template("login.html", name=settings.NAME, @@ -96,7 +98,7 @@ def logout(): return redirect('/login') @app.route('/status.json') -@auth.required +@login_required @require_permission('admin') def status_api(): status = get_status() @@ -105,7 +107,7 @@ def status_api(): @app.route('/api/queries/format', methods=['POST']) -@auth.required +@login_required def format_sql_query(): arguments = request.get_json(force=True) query = arguments.get("query", "") @@ -114,7 +116,7 @@ def format_sql_query(): @app.route('/queries/new', methods=['POST']) -@auth.required +@login_required def create_query_route(): query = request.form.get('query', None) data_source_id = request.form.get('data_source_id', None) @@ -132,7 +134,7 @@ def create_query_route(): class BaseResource(Resource): - decorators = [auth.required] + decorators = [login_required] def __init__(self, *args, **kwargs): super(BaseResource, self).__init__(*args, **kwargs) diff --git a/redash/models.py b/redash/models.py index dceb0ab2ab..b28a357ea8 100644 --- a/redash/models.py +++ b/redash/models.py @@ -15,6 +15,7 @@ from redash import utils, settings, redis_connection from redash.query_runner import get_query_runner +from utils import generate_token class Database(object): @@ -152,6 +153,7 @@ class User(ModelTimestampsMixin, BaseModel, UserMixin, PermissionsCheckMixin): email = peewee.CharField(max_length=320, index=True, unique=True) password_hash = peewee.CharField(max_length=128, null=True) groups = ArrayField(peewee.CharField, default=DEFAULT_GROUPS) + api_key = peewee.CharField(max_length=40) class Meta: db_table = 'users' @@ -169,6 +171,12 @@ def __init__(self, *args, **kwargs): super(User, self).__init__(*args, **kwargs) self._allowed_tables = None + def pre_save(self, created): + super(User, self).pre_save(created) + + if not self.api_key: + self.api_key = generate_token(40) + @property def permissions(self): # TODO: this should be cached. @@ -188,6 +196,10 @@ def allowed_tables(self): def get_by_email(cls, email): return cls.get(cls.email == email) + @classmethod + def get_by_api_key(cls, api_key): + return cls.get(cls.api_key == api_key) + def __unicode__(self): return '%r, %r' % (self.name, self.email) diff --git a/redash/utils.py b/redash/utils.py index 41d23372f2..59faeac614 100644 --- a/redash/utils.py +++ b/redash/utils.py @@ -4,6 +4,7 @@ import decimal import datetime import json +import random import re import hashlib import sqlparse @@ -88,6 +89,14 @@ def gen_query_hash(sql): return hashlib.md5(sql.encode('utf-8')).hexdigest() +def generate_token(length): + chars = ('abcdefghijklmnopqrstuvwxyz' + 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' + '0123456789') + + rand = random.SystemRandom() + return ''.join(rand.choice(chars) for x in range(length)) + class JSONEncoder(json.JSONEncoder): """Custom JSON encoding class, to handle Decimal and datetime.date instances. """ diff --git a/redash/wsgi.py b/redash/wsgi.py index c775362f86..56d2180b1b 100644 --- a/redash/wsgi.py +++ b/redash/wsgi.py @@ -1,5 +1,6 @@ import json from flask import Flask, make_response +from werkzeug.wrappers import Response from flask.ext.restful import Api from redash import settings, utils @@ -24,10 +25,13 @@ db.init_app(app) from redash.authentication import setup_authentication -auth = setup_authentication(app) +setup_authentication(app) @api.representation('application/json') def json_representation(data, code, headers=None): + # Flask-Restful checks only for flask.Response but flask-login uses werkzeug.wrappers.Response + if isinstance(data, Response): + return data resp = make_response(json.dumps(data, cls=utils.JSONEncoder), code) resp.headers.extend(headers or {}) return resp diff --git a/requirements.txt b/requirements.txt index d901a9b82d..582aa79f1e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ Flask==0.10.1 Flask-Admin==1.1.0 Flask-RESTful==0.2.10 -Flask-Login==0.2.9 +Flask-Login==0.2.11 Flask-OAuth==0.12 passlib==1.6.2 Jinja2==2.7.2 @@ -29,4 +29,4 @@ click==3.3 RestrictedPython==3.6.0 wtf-peewee==0.2.3 pysaml2==2.4.0 -pycrypto==2.6.1 \ No newline at end of file +pycrypto==2.6.1 diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 8d2e7a947e..9462e49c65 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -1,9 +1,10 @@ -from flask.ext.login import current_user +from flask import request from mock import patch +import time from tests import BaseTestCase from redash import models from redash.google_oauth import create_and_login_user -from redash.authentication import ApiKeyAuthentication +from redash.authentication import api_key_load_user_from_request, hmac_load_user_from_request, sign from tests.factories import user_factory, query_factory from redash.wsgi import app @@ -18,29 +19,72 @@ def setUp(self): self.query = query_factory.create(api_key=self.api_key) def test_no_api_key(self): - auth = ApiKeyAuthentication() with app.test_client() as c: rv = c.get('/api/queries/{0}'.format(self.query.id)) - self.assertFalse(auth.verify_authentication()) + self.assertIsNone(api_key_load_user_from_request(request)) def test_wrong_api_key(self): - auth = ApiKeyAuthentication() with app.test_client() as c: rv = c.get('/api/queries/{0}'.format(self.query.id), query_string={'api_key': 'whatever'}) - self.assertFalse(auth.verify_authentication()) + self.assertIsNone(api_key_load_user_from_request(request)) def test_correct_api_key(self): - auth = ApiKeyAuthentication() with app.test_client() as c: rv = c.get('/api/queries/{0}'.format(self.query.id), query_string={'api_key': self.api_key}) - self.assertTrue(auth.verify_authentication()) + self.assertIsNotNone(api_key_load_user_from_request(request)) def test_no_query_id(self): - auth = ApiKeyAuthentication() with app.test_client() as c: rv = c.get('/api/queries', query_string={'api_key': self.api_key}) - self.assertFalse(auth.verify_authentication()) + self.assertIsNone(api_key_load_user_from_request(request)) + def test_user_api_key(self): + user = user_factory.create(api_key="user_key") + with app.test_client() as c: + rv = c.get('/api/queries/', query_string={'api_key': user.api_key}) + self.assertEqual(user.id, api_key_load_user_from_request(request).id) + +class TestHMACAuthentication(BaseTestCase): + # + # This is a bad way to write these tests, but the way Flask works doesn't make it easy to write them properly... + # + def setUp(self): + super(TestHMACAuthentication, self).setUp() + self.api_key = 10 + self.query = query_factory.create(api_key=self.api_key) + self.path = '/api/queries/{0}'.format(self.query.id) + self.expires = time.time() + 1800 + + def signature(self, expires): + return sign(self.query.api_key, self.path, expires) + + def test_no_signature(self): + with app.test_client() as c: + rv = c.get(self.path) + self.assertIsNone(hmac_load_user_from_request(request)) + + def test_wrong_signature(self): + with app.test_client() as c: + rv = c.get(self.path, query_string={'signature': 'whatever', 'expires': self.expires}) + self.assertIsNone(hmac_load_user_from_request(request)) + + def test_correct_signature(self): + with app.test_client() as c: + rv = c.get('/api/queries/{0}'.format(self.query.id), query_string={'signature': self.signature(self.expires), 'expires': self.expires}) + self.assertIsNotNone(hmac_load_user_from_request(request)) + + def test_no_query_id(self): + with app.test_client() as c: + rv = c.get('/api/queries', query_string={'api_key': self.api_key}) + self.assertIsNone(hmac_load_user_from_request(request)) + + def test_user_api_key(self): + user = user_factory.create(api_key="user_key") + path = '/api/queries/' + with app.test_client() as c: + signature = sign(user.api_key, path, self.expires) + rv = c.get(path, query_string={'signature': signature, 'expires': self.expires, 'user_id': user.id}) + self.assertEqual(user.id, hmac_load_user_from_request(request).id) class TestCreateAndLoginUser(BaseTestCase): def test_logins_valid_user(self): diff --git a/tests/test_controllers.py b/tests/test_controllers.py index ce98ce8371..d4bd6b8ac7 100644 --- a/tests/test_controllers.py +++ b/tests/test_controllers.py @@ -337,58 +337,6 @@ def setUp(self): super(JobAPITest, self).setUp() -class CsvQueryResultAPITest(BaseTestCase, AuthenticationTestMixin): - def setUp(self): - super(CsvQueryResultAPITest, self).setUp() - - self.paths = [] - self.query_result = query_result_factory.create() - self.query = query_factory.create() - self.path = '/api/queries/{0}/results/{1}.csv'.format(self.query.id, self.query_result.id) - - # TODO: factor out the HMAC authentication tests - - def signature(self, expires): - return sign(self.query.api_key, self.path, expires) - - def test_redirect_when_unauthenticated(self): - with app.test_client() as c: - rv = c.get(self.path) - self.assertEquals(rv.status_code, 302) - - def test_redirect_for_wrong_signature(self): - with app.test_client() as c: - rv = c.get('/api/queries/{0}/results/{1}.csv'.format(self.query.id, self.query_result.id), query_string={'signature': 'whatever', 'expires': 0}) - self.assertEquals(rv.status_code, 302) - - def test_redirect_for_correct_signature_and_wrong_expires(self): - with app.test_client() as c: - rv = c.get('/api/queries/{0}/results/{1}.csv'.format(self.query.id, self.query_result.id), query_string={'signature': self.signature(0), 'expires': 0}) - self.assertEquals(rv.status_code, 302) - - def test_redirect_for_correct_signature_and_no_expires(self): - with app.test_client() as c: - rv = c.get('/api/queries/{0}/results/{1}.csv'.format(self.query.id, self.query_result.id), query_string={'signature': self.signature(time.time()+3600)}) - self.assertEquals(rv.status_code, 302) - - def test_redirect_for_correct_signature_and_expires_too_long(self): - with app.test_client() as c: - expires = time.time()+(10*3600) - rv = c.get('/api/queries/{0}/results/{1}.csv'.format(self.query.id, self.query_result.id), query_string={'signature': self.signature(expires), 'expires': expires}) - self.assertEquals(rv.status_code, 302) - - def test_returns_200_for_correct_signature(self): - with app.test_client() as c: - expires = time.time()+1800 - rv = c.get('/api/queries/{0}/results/{1}.csv'.format(self.query.id, self.query_result.id), query_string={'signature': self.signature(expires), 'expires': expires}) - self.assertEquals(rv.status_code, 200) - - def test_returns_200_for_authenticated_user(self): - with app.test_client() as c, authenticated_user(c): - rv = c.get('/api/queries/{0}/results/{1}.csv'.format(self.query.id, self.query_result.id)) - self.assertEquals(rv.status_code, 200) - - class TestLogin(BaseTestCase): def setUp(self): settings.PASSWORD_LOGIN_ENABLED = True From 6860dde1f72428e661d8a40340268e318169805b Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Wed, 8 Jul 2015 21:29:32 +0300 Subject: [PATCH 2/2] Set api_key to be unique --- redash/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redash/models.py b/redash/models.py index b28a357ea8..3f9a43daf7 100644 --- a/redash/models.py +++ b/redash/models.py @@ -153,7 +153,7 @@ class User(ModelTimestampsMixin, BaseModel, UserMixin, PermissionsCheckMixin): email = peewee.CharField(max_length=320, index=True, unique=True) password_hash = peewee.CharField(max_length=128, null=True) groups = ArrayField(peewee.CharField, default=DEFAULT_GROUPS) - api_key = peewee.CharField(max_length=40) + api_key = peewee.CharField(max_length=40, unique=True) class Meta: db_table = 'users'