import marshmallow as ma
from sqlalchemy.sql import text
from sqlalchemy.sql.schema import Column
from jsonapi.datatypes import DataType, Date, Integer, String
from jsonapi.db.table import Cardinality, FromItem, get_table_name
from jsonapi.exc import APIError, Error
from jsonapi.registry import model_registry, schema_registry
class BaseField:
    """Common state and helpers shared by every field type."""

    def __init__(self, name, data_type=None):
        """
        :param str name: field name
        :param DataType data_type: a DataType instance (optional)
        :raises Error: if ``data_type`` is supplied but is not a DataType instance
        """
        type_is_acceptable = data_type is None or isinstance(data_type, DataType)
        if not type_is_acceptable:
            raise Error('invalid data type provided: "{}"'.format(data_type))

        self.name = name
        self.data_type = data_type
        # The remaining attributes start out empty/False; ``expr`` and
        # ``filter_clause`` are assigned by the subclasses' ``load`` methods.
        self.expr = None
        self.filter_clause = None
        self.exclude = False
        self.sort_by = False

    def get_filter_clause(self):
        """
        Return the ``filter_clause`` of the field's data type.

        A field named ``'id'`` always uses ``Integer``, regardless of its own
        data type. Returns ``None`` when no data type is available.
        """
        if self.name == 'id':
            effective_type = Integer
        else:
            effective_type = self.data_type

        if effective_type is None:
            return None
        return effective_type.filter_clause

    def is_aggregate(self):
        """Return True if this field is an ``Aggregate`` instance."""
        return isinstance(self, Aggregate)

    def is_relationship(self):
        """Return True if this field is a ``Relationship`` instance."""
        return isinstance(self, Relationship)

    def get_ma_field(self):
        """
        Build the marshmallow field used to serialize this field.

        Relationships become a ``Nested`` field wrapping an instance of the
        schema registered as ``<ModelName>Schema``; ``many`` is set for the
        ONE_TO_MANY and MANY_TO_MANY cardinalities. Other fields instantiate
        their data type's ``ma_type``, passing an explicit format for date and
        datetime types.
        """
        if isinstance(self, Relationship):
            schema_cls = schema_registry['{}Schema'.format(self.model.name)]
            nested_schema = schema_cls()
            to_many = (Cardinality.ONE_TO_MANY, Cardinality.MANY_TO_MANY)
            return ma.fields.Nested(nested_schema,
                                    many=self.cardinality in to_many)

        ma_type = self.data_type.ma_type
        # Date is tested before DateTime, keeping the original check order.
        if issubclass(ma_type, ma.fields.Date):
            return ma_type(format=DataType.FORMAT_DATE)
        if issubclass(ma_type, ma.fields.DateTime):
            return ma_type(format=DataType.FORMAT_DATETIME)
        return ma_type()

    def __repr__(self):
        return '<{}({})>'.format(type(self).__name__, self.name)
class Field(BaseField):
    """
    A plain field backed by a database table column or a column expression.

    >>> from jsonapi.datatypes import Date
    >>> from jsonapi.tests.db import users_t
    >>>
    >>> Field('email')
    >>> Field('email-address', users_t.c.email)
    >>> Field('name', lambda c: c.first + ' ' + c.last)
    >>> Field('created-on', data_type=Date)
    """

    def __init__(self, name, col=None, data_type=None):
        """
        :param str name: a unique field name
        :param col: a Column, or a callable that is invoked with the model's
            ``rec`` and returns the field expression (optional)
        :param DataType data_type: when omitted it is looked up with
            ``DataType.get`` on the resolved expression (optional)
        """
        super().__init__(name, data_type=data_type)
        self.col = col

    def load(self, model):
        """
        Resolve the field's expression, data type and filter clause.

        The expression is chosen as follows: the ``id`` field maps to the
        model's primary key; a Column is passed to ``model.get_expr``; any
        other non-None ``col`` is called with ``model.rec``; with no ``col``
        the field name itself is passed to ``model.get_expr``.
        """
        source = self.col
        if self.name == 'id':
            resolved = model.primary_key
        elif isinstance(source, Column):
            resolved = model.get_expr(source)
        elif source is None:
            resolved = model.get_expr(self.name)
        else:
            resolved = source(model.rec)
        self.expr = resolved

        if self.data_type is None:
            self.data_type = DataType.get(resolved)
        self.filter_clause = self.get_filter_clause()
class Aggregate(BaseField):
    """
    An aggregate field (e.g. count, max, etc.) computed over a relationship.

    Loading the field builds the aggregate expression and records the from
    items that have to be added to the model's from clause.
    """

    def __init__(self, name, rel_name, func, col=None, data_type=None):
        """
        :param str name: field name
        :param rel_name: relationship name
        :param func: SQLAlchemy aggregate function (ex. func.count)
        :param col: ``None``, a string passed to ``model.get_expr``, or a
            callable invoked with the model's ``rec`` (optional)
        :param DataType data_type: one of the supported data types (optional)
        """
        super().__init__(name, data_type=data_type)
        self.rel_name = rel_name
        self.rel = None
        self.func = func
        self.col = col
        self.from_items = {}

    def load(self, model):
        """
        Load the relationship, then build the aggregate expression and the
        from items stored under ``model.name``.

        :raises APIError: if the relationship is neither MANY_TO_MANY nor
            ONE_TO_MANY
        """
        rel = model.relationship(self.rel_name)
        self.rel = rel
        rel.load(model)

        # Pick what gets aggregated: the related primary key by default.
        source = self.col
        if source is None:
            target = rel.model.primary_key.distinct()
        elif isinstance(source, str):
            target = model.get_expr(source).distinct()
        else:
            target = source(model.rec)

        # The target is rendered to a string and wrapped in text() before
        # being handed to the aggregate function.
        self.expr = self.func(text(str(target)))
        if self.data_type is None:
            self.data_type = DataType.get(self.expr)
        self.filter_clause = self.get_filter_clause()

        cardinality = rel.cardinality
        if cardinality == Cardinality.MANY_TO_MANY:
            first_ref, _ = rel.ref
            items = (
                FromItem(first_ref.table, left=True),
                FromItem(rel.model.primary_key.table, left=True),
            )
        elif cardinality == Cardinality.ONE_TO_MANY:
            join_condition = rel.parent.primary_key == rel.ref
            items = (
                FromItem(rel.model.primary_key.table,
                         onclause=join_condition,
                         left=True),
            )
        else:
            raise APIError('error: "{}"'.format(self.name), model)
        self.from_items[model.name] = items
class Relationship(BaseField):
    """
    Represents a relationship field.

    >>> from jsonapi.model import ONE_TO_MANY
    >>> from jsonapi.tests.db import articles_t
    >>>
    >>> Relationship('articles', 'ArticleModel', ONE_TO_MANY,
    >>>              articles_t.c.author_id)
    """

    def __init__(self, name, model_name, cardinality, *refs, **kwargs):
        """
        :param str name: relationship name
        :param str model_name: related model name
        :param Cardinality cardinality: relationship cardinality
        :param refs: a variable length list of foreign key columns
        :param where: optional keyword argument, stored as ``self.where``
        :raises Error: if ``refs`` fails the checks in :meth:`check_refs`
        """
        super().__init__(name)
        self.cardinality = cardinality
        self.model_name = model_name
        # check_refs reads self.cardinality, so it must be set first.
        self.check_refs(refs)
        self.refs = refs
        self.model = None
        self.nested = None
        self.parent = None
        self.where = kwargs.get('where', None)

    @staticmethod
    def _ref_names(refs):
        """Return the names of the ref columns as a comma-separated string."""
        return ', '.join(str(ref.name) for ref in refs)

    def check_refs(self, refs):
        """
        Validate the foreign key columns against the relationship cardinality.

        Every ref must be a Column. MANY_TO_MANY needs exactly two refs,
        MANY_TO_ONE and ONE_TO_MANY exactly one, and ONE_TO_ONE at most one.

        :raises Error: if a ref is not a Column or the ref count is wrong
        """
        for ref in refs:
            if not isinstance(ref, Column):
                raise Error('invalid "ref" value: {!r}'.format(ref))

        count = len(refs)
        # The names are joined only when an error is raised. The original
        # passed a generator object to format(), so the message showed
        # "<generator object ...>" instead of the column names.
        if self.cardinality == Cardinality.MANY_TO_MANY and count != 2:
            raise Error('two "ref" columns required: {}'.format(
                self._ref_names(refs)))
        if (self.cardinality in (Cardinality.MANY_TO_ONE,
                                 Cardinality.ONE_TO_MANY)
                and count != 1):
            raise Error('one "ref" column required: {}'.format(
                self._ref_names(refs)))
        if self.cardinality == Cardinality.ONE_TO_ONE and count > 1:
            raise Error('too many "ref" columns: {}'.format(
                self._ref_names(refs)))

    @property
    def ref(self):
        """
        The resolved reference column(s), or ``None``.

        Returns ``None`` until both the related model and the parent are set
        (see :meth:`load`). For MANY_TO_MANY the tuple of both ref columns is
        returned as given. For ONE_TO_ONE the result is ``None``. Otherwise
        the first ref column is looked up by name in the related model's from
        clause and then the parent's, and the first match is returned.
        """
        if not (self.model and self.parent):
            return None
        if self.cardinality == Cardinality.MANY_TO_MANY:
            return self.refs
        if self.cardinality != Cardinality.ONE_TO_ONE:
            for from_clause in (self.model.from_clause,
                                self.parent.from_clause):
                ref = from_clause.get_column(self.refs[0].name)
                if ref is not None:
                    return ref
        return None

    def load(self, parent):
        """
        Bind the relationship to ``parent`` and instantiate the related model.

        The related model is created only on the first call. A class already
        registered as ``_<parent name>_<relationship name>`` is reused;
        otherwise one is derived from the registered ``model_name`` class with
        ``type_`` taken from ``base.get_type()`` and ``from_`` taken from
        ``base.get_from_aliases(self.name)``.
        """
        if not self.model:
            self.parent = parent
            name = '_{}_{}'.format(parent.name, self.name)
            if name in model_registry:
                cls = model_registry[name]
            else:
                base = model_registry[self.model_name]
                cls = type(name,
                           (base,),
                           {'type_': base.get_type(),
                            'from_': base.get_from_aliases(self.name)})
            self.model = cls()
        self.filter_clause = self.get_filter_clause()