From b794b2a6c1dd571a6f553a940b29d4cdd48a1c68 Mon Sep 17 00:00:00 2001 From: Alexis Lahouze Date: Fri, 19 May 2017 00:07:30 +0200 Subject: [PATCH] Use Flask-JWT-extended. --- accountant/__init__.py | 5 +- accountant/views/__init__.py | 27 +++++++++++ accountant/views/accounts.py | 21 ++++---- accountant/views/operations.py | 13 +++-- accountant/views/scheduled_operations.py | 12 ++--- accountant/views/users.py | 62 ++++++------------------ setup.py | 1 + 7 files changed, 69 insertions(+), 72 deletions(-) diff --git a/accountant/__init__.py b/accountant/__init__.py index 6de668f..aed045d 100644 --- a/accountant/__init__.py +++ b/accountant/__init__.py @@ -10,7 +10,7 @@ from flask_alembic import Alembic from flask_alembic import alembic_click from .models import db -from .views import api, cors +from .views import api, cors, jwt # pylint: disable=invalid-name alembic = Alembic() @@ -39,6 +39,9 @@ def create_app(config_path): # API views related stuff. cors.init_app(app) api.init_app(app) + jwt.init_app(app) + # Needed to handle correctly JWT exceptions by restplus. + jwt._set_error_handler_callbacks(api) # pylint: disable=protected-access return app diff --git a/accountant/views/__init__.py b/accountant/views/__init__.py index 34e2550..ae607f6 100644 --- a/accountant/views/__init__.py +++ b/accountant/views/__init__.py @@ -3,6 +3,7 @@ # vim: set tw=80 ts=4 sw=4 sts=4: from flask_cors import CORS +from flask_jwt_extended import JWTManager from flask_restplus import Api from .accounts import ns as accounts_ns @@ -40,3 +41,29 @@ api.add_namespace(users_ns) # pylint: disable=invalid-name cors = CORS() + +jwt = JWTManager() + + +@jwt.user_identity_loader +def user_identity_lookup(user): + """Return information to be in token.""" + return user.id + + +@jwt.expired_token_loader +def expired_token_callback(): + """Handler for expired token.""" + api.abort(401, "Expired token.") + + +@jwt.unauthorized_loader +def unauthorized_callback(message): + """Handler for unauthorized.""" + api.abort(401, message) + + +@jwt.invalid_token_loader +def invalid_token_callback(message): + """Handler for invalid token.""" + api.abort(401, message) diff --git a/accountant/views/accounts.py b/accountant/views/accounts.py index a355ea6..7896191 100644 --- a/accountant/views/accounts.py +++ b/accountant/views/accounts.py @@ -4,14 +4,13 @@ import dateutil.parser +from flask_jwt_extended import jwt_required from flask_restplus import Namespace, Resource, fields from ..models import db from ..models.accounts import Account from ..models.operations import Operation -from .users import requires_auth - # pylint: disable=invalid-name ns = Namespace('account', description='Account management') @@ -128,19 +127,19 @@ range_parser.add_argument( class AccountListResource(Resource): """Resource used to handle account lists.""" - @requires_auth @ns.response(200, 'OK', [account_model]) @ns.marshal_list_with(account_model) + @jwt_required def get(self): """ Returns accounts with their balances.""" return Account.query().all(), 200 - @requires_auth @ns.expect(account_model) @ns.response(201, 'Account created', account_model) @ns.response(406, 'Invalid account data') @ns.marshal_with(account_model) + @jwt_required def post(self): """Create a new account.""" @@ -179,9 +178,9 @@ class AccountListResource(Resource): class AccountResource(Resource): """Resource to handle accounts.""" - @requires_auth @ns.response(200, 'OK', account_model) @ns.marshal_with(account_model) + @jwt_required def get(self, id): """Get an account.""" @@ -197,11 +196,11 @@ class AccountResource(Resource): # causes error on marshalling. return account, 200 - @requires_auth @ns.expect(account_model) @ns.response(200, 'OK', account_model) @ns.response(406, 'Invalid account data') @ns.marshal_with(account_model) + @jwt_required def post(self, id): """Update an account.""" @@ -232,9 +231,9 @@ class AccountResource(Resource): # Return account. return account, 200 - @requires_auth @ns.response(204, 'Account deleted', account_model) @ns.marshal_with(account_model) + @jwt_required def delete(self, id): """Delete an account.""" @@ -256,7 +255,6 @@ class AccountResource(Resource): class SoldsResource(Resource): """Resource to expose solds.""" - @requires_auth @ns.doc( security='apikey', responses={ @@ -265,6 +263,7 @@ class SoldsResource(Resource): 404: 'Account not found' }) @ns.marshal_with(solds_model) + @jwt_required def get(self, id): """Get solds for a specific account and date range.""" @@ -285,7 +284,6 @@ class SoldsResource(Resource): class BalanceResource(Resource): """Resource to expose balances.""" - @requires_auth @ns.doc( security='apikey', responses={ @@ -295,6 +293,7 @@ class BalanceResource(Resource): }) @ns.expect(range_parser) @ns.marshal_with(balance_model) + @jwt_required def get(self, id): """Get account balance for a specific date range.""" @@ -317,7 +316,6 @@ class BalanceResource(Resource): class CategoryResource(Resource): """Resource to expose categories.""" - @requires_auth @ns.doc( security='apikey', responses={ @@ -327,6 +325,7 @@ class CategoryResource(Resource): }) @ns.expect(range_parser) @ns.marshal_list_with(category_model) + @jwt_required def get(self, id): """Get account category balances for a specific date range.""" @@ -339,7 +338,6 @@ class CategoryResource(Resource): class OHLCResource(Resource): """Resource to expose OHLC.""" - @requires_auth @ns.doc( security='apikey', responses={ @@ -349,6 +347,7 @@ class OHLCResource(Resource): }) @ns.expect(range_parser) @ns.marshal_list_with(ohlc_model) + @jwt_required def get(self, id): """Get OHLC data for a specific date range and account.""" diff --git a/accountant/views/operations.py b/accountant/views/operations.py index e1e97bf..faaf91b 100644 --- a/accountant/views/operations.py +++ b/accountant/views/operations.py @@ -4,14 +4,13 @@ import dateutil.parser +from flask_jwt_extended import jwt_required from flask_restplus import Namespace, Resource, fields from ..models import db from ..models.accounts import Account from ..models.operations import Operation -from .users import requires_auth - # pylint: disable=invalid-name ns = Namespace('operation', description='Operation management') @@ -98,10 +97,10 @@ account_range_parser.add_argument( class OperationListResource(Resource): """Resource to handle operation lists.""" - @requires_auth @ns.response(200, 'OK', [operation_with_sold_model]) @ns.expect(parser=account_range_parser) @ns.marshal_list_with(operation_with_sold_model) + @jwt_required def get(self): """Get operations with solds for a specific account.""" @@ -114,11 +113,11 @@ class OperationListResource(Resource): Operation.account_id == data['account_id'] ).all(), 200 - @requires_auth @ns.response(201, 'Operation created', operation_model) @ns.response(404, 'Account not found') @ns.response(406, 'Invalid operation data') @ns.marshal_with(operation_model) + @jwt_required def post(self): """Create a new operation.""" @@ -160,9 +159,9 @@ class OperationListResource(Resource): class OperationResource(Resource): """Resource to handle operations.""" - @requires_auth @ns.response(200, 'OK', operation_model) @ns.marshal_with(operation_model) + @jwt_required def get(self, id): """Get operation.""" @@ -176,11 +175,11 @@ class OperationResource(Resource): return operation, 200 - @requires_auth @ns.expect(operation_model) @ns.response(200, 'OK', operation_model) @ns.response(406, 'Invalid operation data') @ns.marshal_with(operation_model) + @jwt_required def post(self, id): """Update an operation.""" @@ -211,9 +210,9 @@ class OperationResource(Resource): return operation, 200 - @requires_auth @ns.response(204, 'Operation deleted', operation_model) @ns.marshal_with(operation_model) + @jwt_required def delete(self, id): """Delete an operation.""" diff --git a/accountant/views/scheduled_operations.py b/accountant/views/scheduled_operations.py index 07cbf81..70b74d5 100644 --- a/accountant/views/scheduled_operations.py +++ b/accountant/views/scheduled_operations.py @@ -2,6 +2,7 @@ # vim: set tw=80 ts=4 sw=4 sts=4: +from flask_jwt_extended import jwt_required from flask_restplus import Namespace, Resource, fields from sqlalchemy import true @@ -11,7 +12,6 @@ from ..models.accounts import Account from ..models.operations import Operation from ..models.scheduled_operations import ScheduledOperation -from .users import requires_auth # pylint: disable=invalid-name ns = Namespace( @@ -75,10 +75,10 @@ account_id_parser.add_argument( class ScheduledOperationListResource(Resource): """Resource to handle scheduled operation lists.""" - @requires_auth @ns.expect(account_id_parser) @ns.response(200, 'OK', [scheduled_operation_model]) @ns.marshal_list_with(scheduled_operation_model) + @jwt_required def get(self): """Get all scheduled operation for an account.""" @@ -86,12 +86,12 @@ class ScheduledOperationListResource(Resource): return ScheduledOperation.query().filter_by(**data).all(), 200 - @requires_auth @ns.expect(scheduled_operation_model) @ns.response(200, 'OK', scheduled_operation_model) @ns.response(404, 'Account not found') @ns.response(406, 'Invalid operation data') @ns.marshal_with(scheduled_operation_model) + @jwt_required def post(self): """Add a new scheduled operation.""" @@ -137,9 +137,9 @@ class ScheduledOperationListResource(Resource): class ScheduledOperationResource(Resource): """Resource to handle scheduled operations.""" - @requires_auth @ns.response(200, 'OK', scheduled_operation_model) @ns.marshal_with(scheduled_operation_model) + @jwt_required def get(self, id): """Get scheduled operation.""" @@ -153,11 +153,11 @@ class ScheduledOperationResource(Resource): return scheduled_operation, 200 - @requires_auth @ns.response(200, 'OK', scheduled_operation_model) @ns.response(406, 'Invalid scheduled operation data') @ns.expect(scheduled_operation_model) @ns.marshal_with(scheduled_operation_model) + @jwt_required def post(self, id): """Update a scheduled operation.""" @@ -192,10 +192,10 @@ class ScheduledOperationResource(Resource): return scheduled_operation, 200 - @requires_auth @ns.response(200, 'OK', scheduled_operation_model) @ns.response(409, 'Cannot be deleted') @ns.marshal_with(scheduled_operation_model) + @jwt_required def delete(self, id): """Delete a scheduled operation.""" diff --git a/accountant/views/users.py b/accountant/views/users.py index bc8cbf1..e92f613 100644 --- a/accountant/views/users.py +++ b/accountant/views/users.py @@ -2,11 +2,9 @@ # vim: set tw=80 ts=4 sw=4 sts=4: -from functools import wraps - -import arrow - -from flask import request, g, current_app as app +from flask import request +from flask_jwt_extended import (jwt_required, get_jwt_identity, + create_access_token, create_refresh_token) from flask_restplus import Namespace, Resource, fields from ..models.users import User @@ -25,49 +23,24 @@ def load_user_from_auth(auth): return load_user_from_token(token) -def requires_auth(func): - """Decorator to check user authentication before handling a request.""" - - @wraps(func) - def wrapped(*args, **data): - """Check user authentication from requests headers and return 401 if - unauthorized.""" - - user = None - - if 'Authorization' in request.headers: - auth = request.headers['Authorization'] - user = load_user_from_auth(auth) - - if user: - g.user = user - return func(*args, **data) - - ns.abort( - 401, - error_message='Please login before executing this request.' - ) - return wrapped - - # pylint: disable=invalid-name ns = Namespace('user', description='User management') # Token with expiration time and type. token_model = ns.model('Token', { - 'token': fields.String( + 'access_token': fields.String( required=True, readonly=True, - description='Token value'), + description='Access token value'), + 'refresh_token': fields.String( + required=False, + readonly=True, + description='Refresh token value'), 'expiration': fields.DateTime( dt_format='iso8601', - required=True, + required=False, readonly=True, description='Expiration time of the token'), - 'token_type': fields.String( - required=True, - readonly=True, - description='Token type') }) # User model. @@ -112,25 +85,20 @@ class LoginResource(Resource): if not user or not user.verify_password(password): ns.abort(401, error_message="Bad user or password.") - token = user.generate_auth_token() - expiration_time = arrow.now().replace( - seconds=app.config['SESSION_TTL'] - ) - return { - 'token': token, - 'expiration': expiration_time.datetime, - 'token_type': 'Bearer' + 'access_token': create_access_token(identity=user), + 'refresh_token': create_refresh_token(identity=user) }, 200 - @requires_auth @ns.doc( security='apikey', responses={ 200: ('OK', user_model) }) @ns.marshal_with(user_model) + @jwt_required def get(self): """Get authenticated user information.""" + user = User.query().get(get_jwt_identity()) - return g.user, 200 + return user, 200 diff --git a/setup.py b/setup.py index dd2b2ce..493732c 100644 --- a/setup.py +++ b/setup.py @@ -23,6 +23,7 @@ setup( 'Flask-SQLAlchemy>=2.2', 'Flask-restplus>=0.10.1', 'Flask-Cors>=2.1.2', + 'Flask-JWT-Extended>=2.0.0', 'passlib>=1.7.1', 'arrow>=0.10.0', ],