Use Flask-JWT-extended.
This commit is contained in:
parent
e64ea7b3d1
commit
b794b2a6c1
@ -10,7 +10,7 @@ from flask_alembic import Alembic
|
||||
from flask_alembic import alembic_click
|
||||
|
||||
from .models import db
|
||||
from .views import api, cors
|
||||
from .views import api, cors, jwt
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
alembic = Alembic()
|
||||
@ -39,6 +39,9 @@ def create_app(config_path):
|
||||
# API views related stuff.
|
||||
cors.init_app(app)
|
||||
api.init_app(app)
|
||||
jwt.init_app(app)
|
||||
# Needed to handle correctly JWT exceptions by restplus.
|
||||
jwt._set_error_handler_callbacks(api) # pylint: disable=protected-access
|
||||
|
||||
return app
|
||||
|
||||
|
@ -3,6 +3,7 @@
|
||||
# vim: set tw=80 ts=4 sw=4 sts=4:
|
||||
|
||||
from flask_cors import CORS
|
||||
from flask_jwt_extended import JWTManager
|
||||
from flask_restplus import Api
|
||||
|
||||
from .accounts import ns as accounts_ns
|
||||
@ -40,3 +41,29 @@ api.add_namespace(users_ns)
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
cors = CORS()
|
||||
|
||||
jwt = JWTManager()
|
||||
|
||||
|
||||
@jwt.user_identity_loader
|
||||
def user_identity_lookup(user):
|
||||
"""Return information to be in token."""
|
||||
return user.id
|
||||
|
||||
|
||||
@jwt.expired_token_loader
|
||||
def expired_token_callback():
|
||||
"""Handler for expired token."""
|
||||
api.abort(401, "Expired token.")
|
||||
|
||||
|
||||
@jwt.unauthorized_loader
|
||||
def unauthorized_callback(message):
|
||||
"""Handler for unauthorized."""
|
||||
api.abort(401, message)
|
||||
|
||||
|
||||
@jwt.invalid_token_loader
|
||||
def invalid_token_callback(message):
|
||||
"""Handler for invalid token."""
|
||||
api.abort(401, message)
|
||||
|
@ -4,14 +4,13 @@
|
||||
|
||||
import dateutil.parser
|
||||
|
||||
from flask_jwt_extended import jwt_required
|
||||
from flask_restplus import Namespace, Resource, fields
|
||||
|
||||
from ..models import db
|
||||
from ..models.accounts import Account
|
||||
from ..models.operations import Operation
|
||||
|
||||
from .users import requires_auth
|
||||
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
ns = Namespace('account', description='Account management')
|
||||
@ -128,19 +127,19 @@ range_parser.add_argument(
|
||||
class AccountListResource(Resource):
|
||||
"""Resource used to handle account lists."""
|
||||
|
||||
@requires_auth
|
||||
@ns.response(200, 'OK', [account_model])
|
||||
@ns.marshal_list_with(account_model)
|
||||
@jwt_required
|
||||
def get(self):
|
||||
""" Returns accounts with their balances."""
|
||||
|
||||
return Account.query().all(), 200
|
||||
|
||||
@requires_auth
|
||||
@ns.expect(account_model)
|
||||
@ns.response(201, 'Account created', account_model)
|
||||
@ns.response(406, 'Invalid account data')
|
||||
@ns.marshal_with(account_model)
|
||||
@jwt_required
|
||||
def post(self):
|
||||
"""Create a new account."""
|
||||
|
||||
@ -179,9 +178,9 @@ class AccountListResource(Resource):
|
||||
class AccountResource(Resource):
|
||||
"""Resource to handle accounts."""
|
||||
|
||||
@requires_auth
|
||||
@ns.response(200, 'OK', account_model)
|
||||
@ns.marshal_with(account_model)
|
||||
@jwt_required
|
||||
def get(self, id):
|
||||
"""Get an account."""
|
||||
|
||||
@ -197,11 +196,11 @@ class AccountResource(Resource):
|
||||
# causes error on marshalling.
|
||||
return account, 200
|
||||
|
||||
@requires_auth
|
||||
@ns.expect(account_model)
|
||||
@ns.response(200, 'OK', account_model)
|
||||
@ns.response(406, 'Invalid account data')
|
||||
@ns.marshal_with(account_model)
|
||||
@jwt_required
|
||||
def post(self, id):
|
||||
"""Update an account."""
|
||||
|
||||
@ -232,9 +231,9 @@ class AccountResource(Resource):
|
||||
# Return account.
|
||||
return account, 200
|
||||
|
||||
@requires_auth
|
||||
@ns.response(204, 'Account deleted', account_model)
|
||||
@ns.marshal_with(account_model)
|
||||
@jwt_required
|
||||
def delete(self, id):
|
||||
"""Delete an account."""
|
||||
|
||||
@ -256,7 +255,6 @@ class AccountResource(Resource):
|
||||
class SoldsResource(Resource):
|
||||
"""Resource to expose solds."""
|
||||
|
||||
@requires_auth
|
||||
@ns.doc(
|
||||
security='apikey',
|
||||
responses={
|
||||
@ -265,6 +263,7 @@ class SoldsResource(Resource):
|
||||
404: 'Account not found'
|
||||
})
|
||||
@ns.marshal_with(solds_model)
|
||||
@jwt_required
|
||||
def get(self, id):
|
||||
"""Get solds for a specific account and date range."""
|
||||
|
||||
@ -285,7 +284,6 @@ class SoldsResource(Resource):
|
||||
class BalanceResource(Resource):
|
||||
"""Resource to expose balances."""
|
||||
|
||||
@requires_auth
|
||||
@ns.doc(
|
||||
security='apikey',
|
||||
responses={
|
||||
@ -295,6 +293,7 @@ class BalanceResource(Resource):
|
||||
})
|
||||
@ns.expect(range_parser)
|
||||
@ns.marshal_with(balance_model)
|
||||
@jwt_required
|
||||
def get(self, id):
|
||||
"""Get account balance for a specific date range."""
|
||||
|
||||
@ -317,7 +316,6 @@ class BalanceResource(Resource):
|
||||
class CategoryResource(Resource):
|
||||
"""Resource to expose categories."""
|
||||
|
||||
@requires_auth
|
||||
@ns.doc(
|
||||
security='apikey',
|
||||
responses={
|
||||
@ -327,6 +325,7 @@ class CategoryResource(Resource):
|
||||
})
|
||||
@ns.expect(range_parser)
|
||||
@ns.marshal_list_with(category_model)
|
||||
@jwt_required
|
||||
def get(self, id):
|
||||
"""Get account category balances for a specific date range."""
|
||||
|
||||
@ -339,7 +338,6 @@ class CategoryResource(Resource):
|
||||
class OHLCResource(Resource):
|
||||
"""Resource to expose OHLC."""
|
||||
|
||||
@requires_auth
|
||||
@ns.doc(
|
||||
security='apikey',
|
||||
responses={
|
||||
@ -349,6 +347,7 @@ class OHLCResource(Resource):
|
||||
})
|
||||
@ns.expect(range_parser)
|
||||
@ns.marshal_list_with(ohlc_model)
|
||||
@jwt_required
|
||||
def get(self, id):
|
||||
"""Get OHLC data for a specific date range and account."""
|
||||
|
||||
|
@ -4,14 +4,13 @@
|
||||
|
||||
import dateutil.parser
|
||||
|
||||
from flask_jwt_extended import jwt_required
|
||||
from flask_restplus import Namespace, Resource, fields
|
||||
|
||||
from ..models import db
|
||||
from ..models.accounts import Account
|
||||
from ..models.operations import Operation
|
||||
|
||||
from .users import requires_auth
|
||||
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
ns = Namespace('operation', description='Operation management')
|
||||
@ -98,10 +97,10 @@ account_range_parser.add_argument(
|
||||
class OperationListResource(Resource):
|
||||
"""Resource to handle operation lists."""
|
||||
|
||||
@requires_auth
|
||||
@ns.response(200, 'OK', [operation_with_sold_model])
|
||||
@ns.expect(parser=account_range_parser)
|
||||
@ns.marshal_list_with(operation_with_sold_model)
|
||||
@jwt_required
|
||||
def get(self):
|
||||
"""Get operations with solds for a specific account."""
|
||||
|
||||
@ -114,11 +113,11 @@ class OperationListResource(Resource):
|
||||
Operation.account_id == data['account_id']
|
||||
).all(), 200
|
||||
|
||||
@requires_auth
|
||||
@ns.response(201, 'Operation created', operation_model)
|
||||
@ns.response(404, 'Account not found')
|
||||
@ns.response(406, 'Invalid operation data')
|
||||
@ns.marshal_with(operation_model)
|
||||
@jwt_required
|
||||
def post(self):
|
||||
"""Create a new operation."""
|
||||
|
||||
@ -160,9 +159,9 @@ class OperationListResource(Resource):
|
||||
class OperationResource(Resource):
|
||||
"""Resource to handle operations."""
|
||||
|
||||
@requires_auth
|
||||
@ns.response(200, 'OK', operation_model)
|
||||
@ns.marshal_with(operation_model)
|
||||
@jwt_required
|
||||
def get(self, id):
|
||||
"""Get operation."""
|
||||
|
||||
@ -176,11 +175,11 @@ class OperationResource(Resource):
|
||||
|
||||
return operation, 200
|
||||
|
||||
@requires_auth
|
||||
@ns.expect(operation_model)
|
||||
@ns.response(200, 'OK', operation_model)
|
||||
@ns.response(406, 'Invalid operation data')
|
||||
@ns.marshal_with(operation_model)
|
||||
@jwt_required
|
||||
def post(self, id):
|
||||
"""Update an operation."""
|
||||
|
||||
@ -211,9 +210,9 @@ class OperationResource(Resource):
|
||||
|
||||
return operation, 200
|
||||
|
||||
@requires_auth
|
||||
@ns.response(204, 'Operation deleted', operation_model)
|
||||
@ns.marshal_with(operation_model)
|
||||
@jwt_required
|
||||
def delete(self, id):
|
||||
"""Delete an operation."""
|
||||
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
# vim: set tw=80 ts=4 sw=4 sts=4:
|
||||
|
||||
from flask_jwt_extended import jwt_required
|
||||
from flask_restplus import Namespace, Resource, fields
|
||||
|
||||
from sqlalchemy import true
|
||||
@ -11,7 +12,6 @@ from ..models.accounts import Account
|
||||
from ..models.operations import Operation
|
||||
from ..models.scheduled_operations import ScheduledOperation
|
||||
|
||||
from .users import requires_auth
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
ns = Namespace(
|
||||
@ -75,10 +75,10 @@ account_id_parser.add_argument(
|
||||
class ScheduledOperationListResource(Resource):
|
||||
"""Resource to handle scheduled operation lists."""
|
||||
|
||||
@requires_auth
|
||||
@ns.expect(account_id_parser)
|
||||
@ns.response(200, 'OK', [scheduled_operation_model])
|
||||
@ns.marshal_list_with(scheduled_operation_model)
|
||||
@jwt_required
|
||||
def get(self):
|
||||
"""Get all scheduled operation for an account."""
|
||||
|
||||
@ -86,12 +86,12 @@ class ScheduledOperationListResource(Resource):
|
||||
|
||||
return ScheduledOperation.query().filter_by(**data).all(), 200
|
||||
|
||||
@requires_auth
|
||||
@ns.expect(scheduled_operation_model)
|
||||
@ns.response(200, 'OK', scheduled_operation_model)
|
||||
@ns.response(404, 'Account not found')
|
||||
@ns.response(406, 'Invalid operation data')
|
||||
@ns.marshal_with(scheduled_operation_model)
|
||||
@jwt_required
|
||||
def post(self):
|
||||
"""Add a new scheduled operation."""
|
||||
|
||||
@ -137,9 +137,9 @@ class ScheduledOperationListResource(Resource):
|
||||
class ScheduledOperationResource(Resource):
|
||||
"""Resource to handle scheduled operations."""
|
||||
|
||||
@requires_auth
|
||||
@ns.response(200, 'OK', scheduled_operation_model)
|
||||
@ns.marshal_with(scheduled_operation_model)
|
||||
@jwt_required
|
||||
def get(self, id):
|
||||
"""Get scheduled operation."""
|
||||
|
||||
@ -153,11 +153,11 @@ class ScheduledOperationResource(Resource):
|
||||
|
||||
return scheduled_operation, 200
|
||||
|
||||
@requires_auth
|
||||
@ns.response(200, 'OK', scheduled_operation_model)
|
||||
@ns.response(406, 'Invalid scheduled operation data')
|
||||
@ns.expect(scheduled_operation_model)
|
||||
@ns.marshal_with(scheduled_operation_model)
|
||||
@jwt_required
|
||||
def post(self, id):
|
||||
"""Update a scheduled operation."""
|
||||
|
||||
@ -192,10 +192,10 @@ class ScheduledOperationResource(Resource):
|
||||
|
||||
return scheduled_operation, 200
|
||||
|
||||
@requires_auth
|
||||
@ns.response(200, 'OK', scheduled_operation_model)
|
||||
@ns.response(409, 'Cannot be deleted')
|
||||
@ns.marshal_with(scheduled_operation_model)
|
||||
@jwt_required
|
||||
def delete(self, id):
|
||||
"""Delete a scheduled operation."""
|
||||
|
||||
|
@ -2,11 +2,9 @@
|
||||
|
||||
# vim: set tw=80 ts=4 sw=4 sts=4:
|
||||
|
||||
from functools import wraps
|
||||
|
||||
import arrow
|
||||
|
||||
from flask import request, g, current_app as app
|
||||
from flask import request
|
||||
from flask_jwt_extended import (jwt_required, get_jwt_identity,
|
||||
create_access_token, create_refresh_token)
|
||||
from flask_restplus import Namespace, Resource, fields
|
||||
|
||||
from ..models.users import User
|
||||
@ -25,49 +23,24 @@ def load_user_from_auth(auth):
|
||||
return load_user_from_token(token)
|
||||
|
||||
|
||||
def requires_auth(func):
|
||||
"""Decorator to check user authentication before handling a request."""
|
||||
|
||||
@wraps(func)
|
||||
def wrapped(*args, **data):
|
||||
"""Check user authentication from requests headers and return 401 if
|
||||
unauthorized."""
|
||||
|
||||
user = None
|
||||
|
||||
if 'Authorization' in request.headers:
|
||||
auth = request.headers['Authorization']
|
||||
user = load_user_from_auth(auth)
|
||||
|
||||
if user:
|
||||
g.user = user
|
||||
return func(*args, **data)
|
||||
|
||||
ns.abort(
|
||||
401,
|
||||
error_message='Please login before executing this request.'
|
||||
)
|
||||
return wrapped
|
||||
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
ns = Namespace('user', description='User management')
|
||||
|
||||
# Token with expiration time and type.
|
||||
token_model = ns.model('Token', {
|
||||
'token': fields.String(
|
||||
'access_token': fields.String(
|
||||
required=True,
|
||||
readonly=True,
|
||||
description='Token value'),
|
||||
description='Access token value'),
|
||||
'refresh_token': fields.String(
|
||||
required=False,
|
||||
readonly=True,
|
||||
description='Refresh token value'),
|
||||
'expiration': fields.DateTime(
|
||||
dt_format='iso8601',
|
||||
required=True,
|
||||
required=False,
|
||||
readonly=True,
|
||||
description='Expiration time of the token'),
|
||||
'token_type': fields.String(
|
||||
required=True,
|
||||
readonly=True,
|
||||
description='Token type')
|
||||
})
|
||||
|
||||
# User model.
|
||||
@ -112,25 +85,20 @@ class LoginResource(Resource):
|
||||
if not user or not user.verify_password(password):
|
||||
ns.abort(401, error_message="Bad user or password.")
|
||||
|
||||
token = user.generate_auth_token()
|
||||
expiration_time = arrow.now().replace(
|
||||
seconds=app.config['SESSION_TTL']
|
||||
)
|
||||
|
||||
return {
|
||||
'token': token,
|
||||
'expiration': expiration_time.datetime,
|
||||
'token_type': 'Bearer'
|
||||
'access_token': create_access_token(identity=user),
|
||||
'refresh_token': create_refresh_token(identity=user)
|
||||
}, 200
|
||||
|
||||
@requires_auth
|
||||
@ns.doc(
|
||||
security='apikey',
|
||||
responses={
|
||||
200: ('OK', user_model)
|
||||
})
|
||||
@ns.marshal_with(user_model)
|
||||
@jwt_required
|
||||
def get(self):
|
||||
"""Get authenticated user information."""
|
||||
user = User.query().get(get_jwt_identity())
|
||||
|
||||
return g.user, 200
|
||||
return user, 200
|
||||
|
Loading…
Reference in New Issue
Block a user