Use Flask-JWT-extended.

This commit is contained in:
Alexis Lahouze 2017-05-19 00:07:30 +02:00
parent e64ea7b3d1
commit b794b2a6c1
7 changed files with 69 additions and 72 deletions

View File

@ -10,7 +10,7 @@ from flask_alembic import Alembic
from flask_alembic import alembic_click from flask_alembic import alembic_click
from .models import db from .models import db
from .views import api, cors from .views import api, cors, jwt
# pylint: disable=invalid-name # pylint: disable=invalid-name
alembic = Alembic() alembic = Alembic()
@ -39,6 +39,9 @@ def create_app(config_path):
# API views related stuff. # API views related stuff.
cors.init_app(app) cors.init_app(app)
api.init_app(app) api.init_app(app)
jwt.init_app(app)
# Needed to handle correctly JWT exceptions by restplus.
jwt._set_error_handler_callbacks(api) # pylint: disable=protected-access
return app return app

View File

@ -3,6 +3,7 @@
# vim: set tw=80 ts=4 sw=4 sts=4: # vim: set tw=80 ts=4 sw=4 sts=4:
from flask_cors import CORS from flask_cors import CORS
from flask_jwt_extended import JWTManager
from flask_restplus import Api from flask_restplus import Api
from .accounts import ns as accounts_ns from .accounts import ns as accounts_ns
@ -40,3 +41,29 @@ api.add_namespace(users_ns)
# pylint: disable=invalid-name # pylint: disable=invalid-name
cors = CORS() cors = CORS()
jwt = JWTManager()
@jwt.user_identity_loader
def user_identity_lookup(user):
"""Return information to be in token."""
return user.id
@jwt.expired_token_loader
def expired_token_callback():
"""Handler for expired token."""
api.abort(401, "Expired token.")
@jwt.unauthorized_loader
def unauthorized_callback(message):
"""Handler for unauthorized."""
api.abort(401, message)
@jwt.invalid_token_loader
def invalid_token_callback(message):
"""Handler for invalid token."""
api.abort(401, message)

View File

@ -4,14 +4,13 @@
import dateutil.parser import dateutil.parser
from flask_jwt_extended import jwt_required
from flask_restplus import Namespace, Resource, fields from flask_restplus import Namespace, Resource, fields
from ..models import db from ..models import db
from ..models.accounts import Account from ..models.accounts import Account
from ..models.operations import Operation from ..models.operations import Operation
from .users import requires_auth
# pylint: disable=invalid-name # pylint: disable=invalid-name
ns = Namespace('account', description='Account management') ns = Namespace('account', description='Account management')
@ -128,19 +127,19 @@ range_parser.add_argument(
class AccountListResource(Resource): class AccountListResource(Resource):
"""Resource used to handle account lists.""" """Resource used to handle account lists."""
@requires_auth
@ns.response(200, 'OK', [account_model]) @ns.response(200, 'OK', [account_model])
@ns.marshal_list_with(account_model) @ns.marshal_list_with(account_model)
@jwt_required
def get(self): def get(self):
""" Returns accounts with their balances.""" """ Returns accounts with their balances."""
return Account.query().all(), 200 return Account.query().all(), 200
@requires_auth
@ns.expect(account_model) @ns.expect(account_model)
@ns.response(201, 'Account created', account_model) @ns.response(201, 'Account created', account_model)
@ns.response(406, 'Invalid account data') @ns.response(406, 'Invalid account data')
@ns.marshal_with(account_model) @ns.marshal_with(account_model)
@jwt_required
def post(self): def post(self):
"""Create a new account.""" """Create a new account."""
@ -179,9 +178,9 @@ class AccountListResource(Resource):
class AccountResource(Resource): class AccountResource(Resource):
"""Resource to handle accounts.""" """Resource to handle accounts."""
@requires_auth
@ns.response(200, 'OK', account_model) @ns.response(200, 'OK', account_model)
@ns.marshal_with(account_model) @ns.marshal_with(account_model)
@jwt_required
def get(self, id): def get(self, id):
"""Get an account.""" """Get an account."""
@ -197,11 +196,11 @@ class AccountResource(Resource):
# causes error on marshalling. # causes error on marshalling.
return account, 200 return account, 200
@requires_auth
@ns.expect(account_model) @ns.expect(account_model)
@ns.response(200, 'OK', account_model) @ns.response(200, 'OK', account_model)
@ns.response(406, 'Invalid account data') @ns.response(406, 'Invalid account data')
@ns.marshal_with(account_model) @ns.marshal_with(account_model)
@jwt_required
def post(self, id): def post(self, id):
"""Update an account.""" """Update an account."""
@ -232,9 +231,9 @@ class AccountResource(Resource):
# Return account. # Return account.
return account, 200 return account, 200
@requires_auth
@ns.response(204, 'Account deleted', account_model) @ns.response(204, 'Account deleted', account_model)
@ns.marshal_with(account_model) @ns.marshal_with(account_model)
@jwt_required
def delete(self, id): def delete(self, id):
"""Delete an account.""" """Delete an account."""
@ -256,7 +255,6 @@ class AccountResource(Resource):
class SoldsResource(Resource): class SoldsResource(Resource):
"""Resource to expose solds.""" """Resource to expose solds."""
@requires_auth
@ns.doc( @ns.doc(
security='apikey', security='apikey',
responses={ responses={
@ -265,6 +263,7 @@ class SoldsResource(Resource):
404: 'Account not found' 404: 'Account not found'
}) })
@ns.marshal_with(solds_model) @ns.marshal_with(solds_model)
@jwt_required
def get(self, id): def get(self, id):
"""Get solds for a specific account and date range.""" """Get solds for a specific account and date range."""
@ -285,7 +284,6 @@ class SoldsResource(Resource):
class BalanceResource(Resource): class BalanceResource(Resource):
"""Resource to expose balances.""" """Resource to expose balances."""
@requires_auth
@ns.doc( @ns.doc(
security='apikey', security='apikey',
responses={ responses={
@ -295,6 +293,7 @@ class BalanceResource(Resource):
}) })
@ns.expect(range_parser) @ns.expect(range_parser)
@ns.marshal_with(balance_model) @ns.marshal_with(balance_model)
@jwt_required
def get(self, id): def get(self, id):
"""Get account balance for a specific date range.""" """Get account balance for a specific date range."""
@ -317,7 +316,6 @@ class BalanceResource(Resource):
class CategoryResource(Resource): class CategoryResource(Resource):
"""Resource to expose categories.""" """Resource to expose categories."""
@requires_auth
@ns.doc( @ns.doc(
security='apikey', security='apikey',
responses={ responses={
@ -327,6 +325,7 @@ class CategoryResource(Resource):
}) })
@ns.expect(range_parser) @ns.expect(range_parser)
@ns.marshal_list_with(category_model) @ns.marshal_list_with(category_model)
@jwt_required
def get(self, id): def get(self, id):
"""Get account category balances for a specific date range.""" """Get account category balances for a specific date range."""
@ -339,7 +338,6 @@ class CategoryResource(Resource):
class OHLCResource(Resource): class OHLCResource(Resource):
"""Resource to expose OHLC.""" """Resource to expose OHLC."""
@requires_auth
@ns.doc( @ns.doc(
security='apikey', security='apikey',
responses={ responses={
@ -349,6 +347,7 @@ class OHLCResource(Resource):
}) })
@ns.expect(range_parser) @ns.expect(range_parser)
@ns.marshal_list_with(ohlc_model) @ns.marshal_list_with(ohlc_model)
@jwt_required
def get(self, id): def get(self, id):
"""Get OHLC data for a specific date range and account.""" """Get OHLC data for a specific date range and account."""

View File

@ -4,14 +4,13 @@
import dateutil.parser import dateutil.parser
from flask_jwt_extended import jwt_required
from flask_restplus import Namespace, Resource, fields from flask_restplus import Namespace, Resource, fields
from ..models import db from ..models import db
from ..models.accounts import Account from ..models.accounts import Account
from ..models.operations import Operation from ..models.operations import Operation
from .users import requires_auth
# pylint: disable=invalid-name # pylint: disable=invalid-name
ns = Namespace('operation', description='Operation management') ns = Namespace('operation', description='Operation management')
@ -98,10 +97,10 @@ account_range_parser.add_argument(
class OperationListResource(Resource): class OperationListResource(Resource):
"""Resource to handle operation lists.""" """Resource to handle operation lists."""
@requires_auth
@ns.response(200, 'OK', [operation_with_sold_model]) @ns.response(200, 'OK', [operation_with_sold_model])
@ns.expect(parser=account_range_parser) @ns.expect(parser=account_range_parser)
@ns.marshal_list_with(operation_with_sold_model) @ns.marshal_list_with(operation_with_sold_model)
@jwt_required
def get(self): def get(self):
"""Get operations with solds for a specific account.""" """Get operations with solds for a specific account."""
@ -114,11 +113,11 @@ class OperationListResource(Resource):
Operation.account_id == data['account_id'] Operation.account_id == data['account_id']
).all(), 200 ).all(), 200
@requires_auth
@ns.response(201, 'Operation created', operation_model) @ns.response(201, 'Operation created', operation_model)
@ns.response(404, 'Account not found') @ns.response(404, 'Account not found')
@ns.response(406, 'Invalid operation data') @ns.response(406, 'Invalid operation data')
@ns.marshal_with(operation_model) @ns.marshal_with(operation_model)
@jwt_required
def post(self): def post(self):
"""Create a new operation.""" """Create a new operation."""
@ -160,9 +159,9 @@ class OperationListResource(Resource):
class OperationResource(Resource): class OperationResource(Resource):
"""Resource to handle operations.""" """Resource to handle operations."""
@requires_auth
@ns.response(200, 'OK', operation_model) @ns.response(200, 'OK', operation_model)
@ns.marshal_with(operation_model) @ns.marshal_with(operation_model)
@jwt_required
def get(self, id): def get(self, id):
"""Get operation.""" """Get operation."""
@ -176,11 +175,11 @@ class OperationResource(Resource):
return operation, 200 return operation, 200
@requires_auth
@ns.expect(operation_model) @ns.expect(operation_model)
@ns.response(200, 'OK', operation_model) @ns.response(200, 'OK', operation_model)
@ns.response(406, 'Invalid operation data') @ns.response(406, 'Invalid operation data')
@ns.marshal_with(operation_model) @ns.marshal_with(operation_model)
@jwt_required
def post(self, id): def post(self, id):
"""Update an operation.""" """Update an operation."""
@ -211,9 +210,9 @@ class OperationResource(Resource):
return operation, 200 return operation, 200
@requires_auth
@ns.response(204, 'Operation deleted', operation_model) @ns.response(204, 'Operation deleted', operation_model)
@ns.marshal_with(operation_model) @ns.marshal_with(operation_model)
@jwt_required
def delete(self, id): def delete(self, id):
"""Delete an operation.""" """Delete an operation."""

View File

@ -2,6 +2,7 @@
# vim: set tw=80 ts=4 sw=4 sts=4: # vim: set tw=80 ts=4 sw=4 sts=4:
from flask_jwt_extended import jwt_required
from flask_restplus import Namespace, Resource, fields from flask_restplus import Namespace, Resource, fields
from sqlalchemy import true from sqlalchemy import true
@ -11,7 +12,6 @@ from ..models.accounts import Account
from ..models.operations import Operation from ..models.operations import Operation
from ..models.scheduled_operations import ScheduledOperation from ..models.scheduled_operations import ScheduledOperation
from .users import requires_auth
# pylint: disable=invalid-name # pylint: disable=invalid-name
ns = Namespace( ns = Namespace(
@ -75,10 +75,10 @@ account_id_parser.add_argument(
class ScheduledOperationListResource(Resource): class ScheduledOperationListResource(Resource):
"""Resource to handle scheduled operation lists.""" """Resource to handle scheduled operation lists."""
@requires_auth
@ns.expect(account_id_parser) @ns.expect(account_id_parser)
@ns.response(200, 'OK', [scheduled_operation_model]) @ns.response(200, 'OK', [scheduled_operation_model])
@ns.marshal_list_with(scheduled_operation_model) @ns.marshal_list_with(scheduled_operation_model)
@jwt_required
def get(self): def get(self):
"""Get all scheduled operation for an account.""" """Get all scheduled operation for an account."""
@ -86,12 +86,12 @@ class ScheduledOperationListResource(Resource):
return ScheduledOperation.query().filter_by(**data).all(), 200 return ScheduledOperation.query().filter_by(**data).all(), 200
@requires_auth
@ns.expect(scheduled_operation_model) @ns.expect(scheduled_operation_model)
@ns.response(200, 'OK', scheduled_operation_model) @ns.response(200, 'OK', scheduled_operation_model)
@ns.response(404, 'Account not found') @ns.response(404, 'Account not found')
@ns.response(406, 'Invalid operation data') @ns.response(406, 'Invalid operation data')
@ns.marshal_with(scheduled_operation_model) @ns.marshal_with(scheduled_operation_model)
@jwt_required
def post(self): def post(self):
"""Add a new scheduled operation.""" """Add a new scheduled operation."""
@ -137,9 +137,9 @@ class ScheduledOperationListResource(Resource):
class ScheduledOperationResource(Resource): class ScheduledOperationResource(Resource):
"""Resource to handle scheduled operations.""" """Resource to handle scheduled operations."""
@requires_auth
@ns.response(200, 'OK', scheduled_operation_model) @ns.response(200, 'OK', scheduled_operation_model)
@ns.marshal_with(scheduled_operation_model) @ns.marshal_with(scheduled_operation_model)
@jwt_required
def get(self, id): def get(self, id):
"""Get scheduled operation.""" """Get scheduled operation."""
@ -153,11 +153,11 @@ class ScheduledOperationResource(Resource):
return scheduled_operation, 200 return scheduled_operation, 200
@requires_auth
@ns.response(200, 'OK', scheduled_operation_model) @ns.response(200, 'OK', scheduled_operation_model)
@ns.response(406, 'Invalid scheduled operation data') @ns.response(406, 'Invalid scheduled operation data')
@ns.expect(scheduled_operation_model) @ns.expect(scheduled_operation_model)
@ns.marshal_with(scheduled_operation_model) @ns.marshal_with(scheduled_operation_model)
@jwt_required
def post(self, id): def post(self, id):
"""Update a scheduled operation.""" """Update a scheduled operation."""
@ -192,10 +192,10 @@ class ScheduledOperationResource(Resource):
return scheduled_operation, 200 return scheduled_operation, 200
@requires_auth
@ns.response(200, 'OK', scheduled_operation_model) @ns.response(200, 'OK', scheduled_operation_model)
@ns.response(409, 'Cannot be deleted') @ns.response(409, 'Cannot be deleted')
@ns.marshal_with(scheduled_operation_model) @ns.marshal_with(scheduled_operation_model)
@jwt_required
def delete(self, id): def delete(self, id):
"""Delete a scheduled operation.""" """Delete a scheduled operation."""

View File

@ -2,11 +2,9 @@
# vim: set tw=80 ts=4 sw=4 sts=4: # vim: set tw=80 ts=4 sw=4 sts=4:
from functools import wraps from flask import request
from flask_jwt_extended import (jwt_required, get_jwt_identity,
import arrow create_access_token, create_refresh_token)
from flask import request, g, current_app as app
from flask_restplus import Namespace, Resource, fields from flask_restplus import Namespace, Resource, fields
from ..models.users import User from ..models.users import User
@ -25,49 +23,24 @@ def load_user_from_auth(auth):
return load_user_from_token(token) return load_user_from_token(token)
def requires_auth(func):
"""Decorator to check user authentication before handling a request."""
@wraps(func)
def wrapped(*args, **data):
"""Check user authentication from requests headers and return 401 if
unauthorized."""
user = None
if 'Authorization' in request.headers:
auth = request.headers['Authorization']
user = load_user_from_auth(auth)
if user:
g.user = user
return func(*args, **data)
ns.abort(
401,
error_message='Please login before executing this request.'
)
return wrapped
# pylint: disable=invalid-name # pylint: disable=invalid-name
ns = Namespace('user', description='User management') ns = Namespace('user', description='User management')
# Token with expiration time and type. # Token with expiration time and type.
token_model = ns.model('Token', { token_model = ns.model('Token', {
'token': fields.String( 'access_token': fields.String(
required=True, required=True,
readonly=True, readonly=True,
description='Token value'), description='Access token value'),
'refresh_token': fields.String(
required=False,
readonly=True,
description='Refresh token value'),
'expiration': fields.DateTime( 'expiration': fields.DateTime(
dt_format='iso8601', dt_format='iso8601',
required=True, required=False,
readonly=True, readonly=True,
description='Expiration time of the token'), description='Expiration time of the token'),
'token_type': fields.String(
required=True,
readonly=True,
description='Token type')
}) })
# User model. # User model.
@ -112,25 +85,20 @@ class LoginResource(Resource):
if not user or not user.verify_password(password): if not user or not user.verify_password(password):
ns.abort(401, error_message="Bad user or password.") ns.abort(401, error_message="Bad user or password.")
token = user.generate_auth_token()
expiration_time = arrow.now().replace(
seconds=app.config['SESSION_TTL']
)
return { return {
'token': token, 'access_token': create_access_token(identity=user),
'expiration': expiration_time.datetime, 'refresh_token': create_refresh_token(identity=user)
'token_type': 'Bearer'
}, 200 }, 200
@requires_auth
@ns.doc( @ns.doc(
security='apikey', security='apikey',
responses={ responses={
200: ('OK', user_model) 200: ('OK', user_model)
}) })
@ns.marshal_with(user_model) @ns.marshal_with(user_model)
@jwt_required
def get(self): def get(self):
"""Get authenticated user information.""" """Get authenticated user information."""
user = User.query().get(get_jwt_identity())
return g.user, 200 return user, 200

View File

@ -23,6 +23,7 @@ setup(
'Flask-SQLAlchemy>=2.2', 'Flask-SQLAlchemy>=2.2',
'Flask-restplus>=0.10.1', 'Flask-restplus>=0.10.1',
'Flask-Cors>=2.1.2', 'Flask-Cors>=2.1.2',
'Flask-JWT-Extended>=2.0.0',
'passlib>=1.7.1', 'passlib>=1.7.1',
'arrow>=0.10.0', 'arrow>=0.10.0',
], ],