Use Flask-JWT-extended.

This commit is contained in:
Alexis Lahouze 2017-05-19 00:07:30 +02:00
parent e64ea7b3d1
commit b794b2a6c1
7 changed files with 69 additions and 72 deletions

View File

@ -10,7 +10,7 @@ from flask_alembic import Alembic
from flask_alembic import alembic_click
from .models import db
from .views import api, cors
from .views import api, cors, jwt
# pylint: disable=invalid-name
alembic = Alembic()
@ -39,6 +39,9 @@ def create_app(config_path):
# API views related stuff.
cors.init_app(app)
api.init_app(app)
jwt.init_app(app)
# Needed to handle correctly JWT exceptions by restplus.
jwt._set_error_handler_callbacks(api) # pylint: disable=protected-access
return app

View File

@ -3,6 +3,7 @@
# vim: set tw=80 ts=4 sw=4 sts=4:
from flask_cors import CORS
from flask_jwt_extended import JWTManager
from flask_restplus import Api
from .accounts import ns as accounts_ns
@ -40,3 +41,29 @@ api.add_namespace(users_ns)
# pylint: disable=invalid-name
cors = CORS()
jwt = JWTManager()
@jwt.user_identity_loader
def user_identity_lookup(user):
"""Return information to be in token."""
return user.id
@jwt.expired_token_loader
def expired_token_callback():
"""Handler for expired token."""
api.abort(401, "Expired token.")
@jwt.unauthorized_loader
def unauthorized_callback(message):
"""Handler for unauthorized."""
api.abort(401, message)
@jwt.invalid_token_loader
def invalid_token_callback(message):
"""Handler for invalid token."""
api.abort(401, message)

View File

@ -4,14 +4,13 @@
import dateutil.parser
from flask_jwt_extended import jwt_required
from flask_restplus import Namespace, Resource, fields
from ..models import db
from ..models.accounts import Account
from ..models.operations import Operation
from .users import requires_auth
# pylint: disable=invalid-name
ns = Namespace('account', description='Account management')
@ -128,19 +127,19 @@ range_parser.add_argument(
class AccountListResource(Resource):
"""Resource used to handle account lists."""
@requires_auth
@ns.response(200, 'OK', [account_model])
@ns.marshal_list_with(account_model)
@jwt_required
def get(self):
""" Returns accounts with their balances."""
return Account.query().all(), 200
@requires_auth
@ns.expect(account_model)
@ns.response(201, 'Account created', account_model)
@ns.response(406, 'Invalid account data')
@ns.marshal_with(account_model)
@jwt_required
def post(self):
"""Create a new account."""
@ -179,9 +178,9 @@ class AccountListResource(Resource):
class AccountResource(Resource):
"""Resource to handle accounts."""
@requires_auth
@ns.response(200, 'OK', account_model)
@ns.marshal_with(account_model)
@jwt_required
def get(self, id):
"""Get an account."""
@ -197,11 +196,11 @@ class AccountResource(Resource):
# causes error on marshalling.
return account, 200
@requires_auth
@ns.expect(account_model)
@ns.response(200, 'OK', account_model)
@ns.response(406, 'Invalid account data')
@ns.marshal_with(account_model)
@jwt_required
def post(self, id):
"""Update an account."""
@ -232,9 +231,9 @@ class AccountResource(Resource):
# Return account.
return account, 200
@requires_auth
@ns.response(204, 'Account deleted', account_model)
@ns.marshal_with(account_model)
@jwt_required
def delete(self, id):
"""Delete an account."""
@ -256,7 +255,6 @@ class AccountResource(Resource):
class SoldsResource(Resource):
"""Resource to expose solds."""
@requires_auth
@ns.doc(
security='apikey',
responses={
@ -265,6 +263,7 @@ class SoldsResource(Resource):
404: 'Account not found'
})
@ns.marshal_with(solds_model)
@jwt_required
def get(self, id):
"""Get solds for a specific account and date range."""
@ -285,7 +284,6 @@ class SoldsResource(Resource):
class BalanceResource(Resource):
"""Resource to expose balances."""
@requires_auth
@ns.doc(
security='apikey',
responses={
@ -295,6 +293,7 @@ class BalanceResource(Resource):
})
@ns.expect(range_parser)
@ns.marshal_with(balance_model)
@jwt_required
def get(self, id):
"""Get account balance for a specific date range."""
@ -317,7 +316,6 @@ class BalanceResource(Resource):
class CategoryResource(Resource):
"""Resource to expose categories."""
@requires_auth
@ns.doc(
security='apikey',
responses={
@ -327,6 +325,7 @@ class CategoryResource(Resource):
})
@ns.expect(range_parser)
@ns.marshal_list_with(category_model)
@jwt_required
def get(self, id):
"""Get account category balances for a specific date range."""
@ -339,7 +338,6 @@ class CategoryResource(Resource):
class OHLCResource(Resource):
"""Resource to expose OHLC."""
@requires_auth
@ns.doc(
security='apikey',
responses={
@ -349,6 +347,7 @@ class OHLCResource(Resource):
})
@ns.expect(range_parser)
@ns.marshal_list_with(ohlc_model)
@jwt_required
def get(self, id):
"""Get OHLC data for a specific date range and account."""

View File

@ -4,14 +4,13 @@
import dateutil.parser
from flask_jwt_extended import jwt_required
from flask_restplus import Namespace, Resource, fields
from ..models import db
from ..models.accounts import Account
from ..models.operations import Operation
from .users import requires_auth
# pylint: disable=invalid-name
ns = Namespace('operation', description='Operation management')
@ -98,10 +97,10 @@ account_range_parser.add_argument(
class OperationListResource(Resource):
"""Resource to handle operation lists."""
@requires_auth
@ns.response(200, 'OK', [operation_with_sold_model])
@ns.expect(parser=account_range_parser)
@ns.marshal_list_with(operation_with_sold_model)
@jwt_required
def get(self):
"""Get operations with solds for a specific account."""
@ -114,11 +113,11 @@ class OperationListResource(Resource):
Operation.account_id == data['account_id']
).all(), 200
@requires_auth
@ns.response(201, 'Operation created', operation_model)
@ns.response(404, 'Account not found')
@ns.response(406, 'Invalid operation data')
@ns.marshal_with(operation_model)
@jwt_required
def post(self):
"""Create a new operation."""
@ -160,9 +159,9 @@ class OperationListResource(Resource):
class OperationResource(Resource):
"""Resource to handle operations."""
@requires_auth
@ns.response(200, 'OK', operation_model)
@ns.marshal_with(operation_model)
@jwt_required
def get(self, id):
"""Get operation."""
@ -176,11 +175,11 @@ class OperationResource(Resource):
return operation, 200
@requires_auth
@ns.expect(operation_model)
@ns.response(200, 'OK', operation_model)
@ns.response(406, 'Invalid operation data')
@ns.marshal_with(operation_model)
@jwt_required
def post(self, id):
"""Update an operation."""
@ -211,9 +210,9 @@ class OperationResource(Resource):
return operation, 200
@requires_auth
@ns.response(204, 'Operation deleted', operation_model)
@ns.marshal_with(operation_model)
@jwt_required
def delete(self, id):
"""Delete an operation."""

View File

@ -2,6 +2,7 @@
# vim: set tw=80 ts=4 sw=4 sts=4:
from flask_jwt_extended import jwt_required
from flask_restplus import Namespace, Resource, fields
from sqlalchemy import true
@ -11,7 +12,6 @@ from ..models.accounts import Account
from ..models.operations import Operation
from ..models.scheduled_operations import ScheduledOperation
from .users import requires_auth
# pylint: disable=invalid-name
ns = Namespace(
@ -75,10 +75,10 @@ account_id_parser.add_argument(
class ScheduledOperationListResource(Resource):
"""Resource to handle scheduled operation lists."""
@requires_auth
@ns.expect(account_id_parser)
@ns.response(200, 'OK', [scheduled_operation_model])
@ns.marshal_list_with(scheduled_operation_model)
@jwt_required
def get(self):
"""Get all scheduled operation for an account."""
@ -86,12 +86,12 @@ class ScheduledOperationListResource(Resource):
return ScheduledOperation.query().filter_by(**data).all(), 200
@requires_auth
@ns.expect(scheduled_operation_model)
@ns.response(200, 'OK', scheduled_operation_model)
@ns.response(404, 'Account not found')
@ns.response(406, 'Invalid operation data')
@ns.marshal_with(scheduled_operation_model)
@jwt_required
def post(self):
"""Add a new scheduled operation."""
@ -137,9 +137,9 @@ class ScheduledOperationListResource(Resource):
class ScheduledOperationResource(Resource):
"""Resource to handle scheduled operations."""
@requires_auth
@ns.response(200, 'OK', scheduled_operation_model)
@ns.marshal_with(scheduled_operation_model)
@jwt_required
def get(self, id):
"""Get scheduled operation."""
@ -153,11 +153,11 @@ class ScheduledOperationResource(Resource):
return scheduled_operation, 200
@requires_auth
@ns.response(200, 'OK', scheduled_operation_model)
@ns.response(406, 'Invalid scheduled operation data')
@ns.expect(scheduled_operation_model)
@ns.marshal_with(scheduled_operation_model)
@jwt_required
def post(self, id):
"""Update a scheduled operation."""
@ -192,10 +192,10 @@ class ScheduledOperationResource(Resource):
return scheduled_operation, 200
@requires_auth
@ns.response(200, 'OK', scheduled_operation_model)
@ns.response(409, 'Cannot be deleted')
@ns.marshal_with(scheduled_operation_model)
@jwt_required
def delete(self, id):
"""Delete a scheduled operation."""

View File

@ -2,11 +2,9 @@
# vim: set tw=80 ts=4 sw=4 sts=4:
from functools import wraps
import arrow
from flask import request, g, current_app as app
from flask import request
from flask_jwt_extended import (jwt_required, get_jwt_identity,
create_access_token, create_refresh_token)
from flask_restplus import Namespace, Resource, fields
from ..models.users import User
@ -25,49 +23,24 @@ def load_user_from_auth(auth):
return load_user_from_token(token)
def requires_auth(func):
"""Decorator to check user authentication before handling a request."""
@wraps(func)
def wrapped(*args, **data):
"""Check user authentication from requests headers and return 401 if
unauthorized."""
user = None
if 'Authorization' in request.headers:
auth = request.headers['Authorization']
user = load_user_from_auth(auth)
if user:
g.user = user
return func(*args, **data)
ns.abort(
401,
error_message='Please login before executing this request.'
)
return wrapped
# pylint: disable=invalid-name
ns = Namespace('user', description='User management')
# Token with expiration time and type.
token_model = ns.model('Token', {
'token': fields.String(
'access_token': fields.String(
required=True,
readonly=True,
description='Token value'),
description='Access token value'),
'refresh_token': fields.String(
required=False,
readonly=True,
description='Refresh token value'),
'expiration': fields.DateTime(
dt_format='iso8601',
required=True,
required=False,
readonly=True,
description='Expiration time of the token'),
'token_type': fields.String(
required=True,
readonly=True,
description='Token type')
})
# User model.
@ -112,25 +85,20 @@ class LoginResource(Resource):
if not user or not user.verify_password(password):
ns.abort(401, error_message="Bad user or password.")
token = user.generate_auth_token()
expiration_time = arrow.now().replace(
seconds=app.config['SESSION_TTL']
)
return {
'token': token,
'expiration': expiration_time.datetime,
'token_type': 'Bearer'
'access_token': create_access_token(identity=user),
'refresh_token': create_refresh_token(identity=user)
}, 200
@requires_auth
@ns.doc(
security='apikey',
responses={
200: ('OK', user_model)
})
@ns.marshal_with(user_model)
@jwt_required
def get(self):
"""Get authenticated user information."""
user = User.query().get(get_jwt_identity())
return g.user, 200
return user, 200

View File

@ -23,6 +23,7 @@ setup(
'Flask-SQLAlchemy>=2.2',
'Flask-restplus>=0.10.1',
'Flask-Cors>=2.1.2',
'Flask-JWT-Extended>=2.0.0',
'passlib>=1.7.1',
'arrow>=0.10.0',
],