# Copyright: See the LICENSE file.
import itertools
import logging
from . import enums, errors, utils
logger = logging.getLogger('factory.generate')
class BaseDeclaration(utils.OrderedBase):
"""A factory declaration.
Declarations mark an attribute as needing lazy evaluation.
This allows them to refer to attributes defined by other BaseDeclarations
in the same factory.
"""
FACTORY_BUILDER_PHASE = enums.BuilderPhase.ATTRIBUTE_RESOLUTION
#: Whether to unroll the context before evaluating the declaration.
#: Set to False on declarations that perform their own unrolling.
UNROLL_CONTEXT_BEFORE_EVALUATION = True
def unroll_context(self, instance, step, context):
if not self.UNROLL_CONTEXT_BEFORE_EVALUATION:
return context
if not any(enums.get_builder_phase(v) for v in context.values()):
# Optimization for simple contexts - don't do anything.
return context
import factory.base
subfactory = factory.base.DictFactory
return step.recurse(subfactory, context, force_sequence=step.sequence)
def evaluate(self, instance, step, extra):
"""Evaluate this declaration.
Args:
instance (builder.Resolver): The object holding currently computed
attributes
step: a factory.builder.BuildStep
extra (dict): additional, call-time added kwargs
for the step.
"""
raise NotImplementedError('This is an abstract method')
class OrderedDeclaration(BaseDeclaration):
"""Compatibility"""
# FIXME(rbarrois)
class LazyFunction(BaseDeclaration):
"""Simplest BaseDeclaration computed by calling the given function.
Attributes:
function (function): a function without arguments and
returning the computed value.
"""
def __init__(self, function, *args, **kwargs):
super().__init__(*args, **kwargs)
self.function = function
def evaluate(self, instance, step, extra):
logger.debug("LazyFunction: Evaluating %r on %r", self.function, step)
return self.function()
class LazyAttribute(BaseDeclaration):
"""Specific BaseDeclaration computed using a lambda.
Attributes:
function (function): a function, expecting the current LazyStub and
returning the computed value.
"""
def __init__(self, function, *args, **kwargs):
super().__init__(*args, **kwargs)
self.function = function
def evaluate(self, instance, step, extra):
logger.debug("LazyAttribute: Evaluating %r on %r", self.function, instance)
return self.function(instance)
class _UNSPECIFIED:
pass
def deepgetattr(obj, name, default=_UNSPECIFIED):
"""Try to retrieve the given attribute of an object, digging on '.'.
This is an extended getattr, digging deeper if '.' is found.
Args:
obj (object): the object of which an attribute should be read
name (str): the name of an attribute to look up.
default (object): the default value to use if the attribute wasn't found
Returns:
the attribute pointed to by 'name', splitting on '.'.
Raises:
AttributeError: if obj has no 'name' attribute.
"""
try:
if '.' in name:
attr, subname = name.split('.', 1)
return deepgetattr(getattr(obj, attr), subname, default)
else:
return getattr(obj, name)
except AttributeError:
if default is _UNSPECIFIED:
raise
else:
return default
class SelfAttribute(BaseDeclaration):
"""Specific BaseDeclaration copying values from other fields.
If the field name starts with two dots or more, the lookup will be anchored
in the related 'parent'.
Attributes:
depth (int): the number of steps to go up in the containers chain
attribute_name (str): the name of the attribute to copy.
default (object): the default value to use if the attribute doesn't
exist.
"""
def __init__(self, attribute_name, default=_UNSPECIFIED, *args, **kwargs):
super().__init__(*args, **kwargs)
depth = len(attribute_name) - len(attribute_name.lstrip('.'))
attribute_name = attribute_name[depth:]
self.depth = depth
self.attribute_name = attribute_name
self.default = default
def evaluate(self, instance, step, extra):
if self.depth > 1:
# Fetching from a parent
target = step.chain[self.depth - 1]
else:
target = instance
logger.debug("SelfAttribute: Picking attribute %r on %r", self.attribute_name, target)
return deepgetattr(target, self.attribute_name, self.default)
def __repr__(self):
return '<%s(%r, default=%r)>' % (
self.__class__.__name__,
self.attribute_name,
self.default,
)
class Iterator(BaseDeclaration):
"""Fill this value using the values returned by an iterator.
Warning: the iterator should not end !
Attributes:
iterator (iterable): the iterator whose value should be used.
getter (callable or None): a function to parse returned values
"""
def __init__(self, iterator, cycle=True, getter=None):
super().__init__()
self.getter = getter
self.iterator = None
if cycle:
self.iterator_builder = lambda: utils.ResetableIterator(itertools.cycle(iterator))
else:
self.iterator_builder = lambda: utils.ResetableIterator(iterator)
def evaluate(self, instance, step, extra):
# Begin unrolling as late as possible.
# This helps with ResetableIterator(MyModel.objects.all())
if self.iterator is None:
self.iterator = self.iterator_builder()
logger.debug("Iterator: Fetching next value from %r", self.iterator)
value = next(iter(self.iterator))
if self.getter is None:
return value
return self.getter(value)
def reset(self):
"""Reset the internal iterator."""
if self.iterator is not None:
self.iterator.reset()
class Sequence(BaseDeclaration):
"""Specific BaseDeclaration to use for 'sequenced' fields.
These fields are typically used to generate increasing unique values.
Attributes:
function (function): A function, expecting the current sequence counter
and returning the computed value.
"""
def __init__(self, function):
super().__init__()
self.function = function
def evaluate(self, instance, step, extra):
logger.debug("Sequence: Computing next value of %r for seq=%s", self.function, step.sequence)
return self.function(int(step.sequence))
class LazyAttributeSequence(Sequence):
"""Composite of a LazyAttribute and a Sequence.
Attributes:
function (function): A function, expecting the current LazyStub and the
current sequence counter.
type (function): A function converting an integer into the expected kind
of counter for the 'function' attribute.
"""
def evaluate(self, instance, step, extra):
logger.debug(
"LazyAttributeSequence: Computing next value of %r for seq=%s, obj=%r",
self.function, step.sequence, instance)
return self.function(instance, int(step.sequence))
class ContainerAttribute(BaseDeclaration):
"""Variant of LazyAttribute, also receives the containers of the object.
Attributes:
function (function): A function, expecting the current LazyStub and the
(optional) object having a subfactory containing this attribute.
strict (bool): Whether evaluating should fail when the containers are
not passed in (i.e used outside a SubFactory).
"""
def __init__(self, function, strict=True, *args, **kwargs):
super().__init__(*args, **kwargs)
self.function = function
self.strict = strict
def evaluate(self, instance, step, extra):
"""Evaluate the current ContainerAttribute.
Args:
obj (LazyStub): a lazy stub of the object being constructed, if
needed.
containers (list of LazyStub): a list of lazy stubs of factories
being evaluated in a chain, each item being a future field of
next one.
"""
# Strip the current instance from the chain
chain = step.chain[1:]
if self.strict and not chain:
raise TypeError(
"A ContainerAttribute in 'strict' mode can only be used "
"within a SubFactory.")
return self.function(instance, chain)
class ParameteredAttribute(BaseDeclaration):
"""Base class for attributes expecting parameters.
Attributes:
defaults (dict): Default values for the parameters.
May be overridden by call-time parameters.
Class attributes:
CONTAINERS_FIELD (str): name of the field, if any, where container
information (e.g for SubFactory) should be stored. If empty,
containers data isn't merged into generate() parameters.
"""
CONTAINERS_FIELD = '__containers'
# Whether to add the current object to the stack of containers
EXTEND_CONTAINERS = False
def __init__(self, **kwargs):
super().__init__()
self.defaults = kwargs
def _prepare_containers(self, obj, containers=()):
if self.EXTEND_CONTAINERS:
return (obj,) + tuple(containers)
return containers
def evaluate(self, instance, step, extra):
"""Evaluate the current definition and fill its attributes.
Uses attributes definition in the following order:
- values defined when defining the ParameteredAttribute
- additional values defined when instantiating the containing factory
Args:
instance (builder.Resolver): The object holding currently computed
attributes
step: a factory.builder.BuildStep
extra (dict): additional, call-time added kwargs
for the step.
"""
defaults = dict(self.defaults)
if extra:
defaults.update(extra)
return self.generate(step, defaults)
def generate(self, step, params):
"""Actually generate the related attribute.
Args:
sequence (int): the current sequence number
obj (LazyStub): the object being constructed
create (bool): whether the calling factory was in 'create' or
'build' mode
params (dict): parameters inherited from init and evaluation-time
overrides.
Returns:
Computed value for the current declaration.
"""
raise NotImplementedError()
class _FactoryWrapper:
"""Handle a 'factory' arg.
Such args can be either a Factory subclass, or a fully qualified import
path for that subclass (e.g 'myapp.factories.MyFactory').
"""
def __init__(self, factory_or_path):
self.factory = None
self.module = self.name = ''
if isinstance(factory_or_path, type):
self.factory = factory_or_path
else:
if not (isinstance(factory_or_path, str) and '.' in factory_or_path):
raise ValueError(
"A factory= argument must receive either a class "
"or the fully qualified path to a Factory subclass; got "
"%r instead." % factory_or_path)
self.module, self.name = factory_or_path.rsplit('.', 1)
def get(self):
if self.factory is None:
self.factory = utils.import_object(
self.module,
self.name,
)
return self.factory
def __repr__(self):
if self.factory is None:
return '<_FactoryImport: %s.%s>' % (self.module, self.name)
else:
return '<_FactoryImport: %s>' % self.factory.__class__
class SubFactory(ParameteredAttribute):
"""Base class for attributes based upon a sub-factory.
Attributes:
defaults (dict): Overrides to the defaults defined in the wrapped
factory
factory (base.Factory): the wrapped factory
"""
EXTEND_CONTAINERS = True
FORCE_SEQUENCE = False
UNROLL_CONTEXT_BEFORE_EVALUATION = False
def __init__(self, factory, **kwargs):
super().__init__(**kwargs)
self.factory_wrapper = _FactoryWrapper(factory)
def get_factory(self):
"""Retrieve the wrapped factory.Factory subclass."""
return self.factory_wrapper.get()
def generate(self, step, params):
"""Evaluate the current definition and fill its attributes.
Args:
step: a factory.builder.BuildStep
params (dict): additional, call-time added kwargs
for the step.
"""
subfactory = self.get_factory()
logger.debug(
"SubFactory: Instantiating %s.%s(%s), create=%r",
subfactory.__module__, subfactory.__name__,
utils.log_pprint(kwargs=params),
step,
)
force_sequence = step.sequence if self.FORCE_SEQUENCE else None
return step.recurse(subfactory, params, force_sequence=force_sequence)
class Dict(SubFactory):
"""Fill a dict with usual declarations."""
FORCE_SEQUENCE = True
def __init__(self, params, dict_factory='factory.DictFactory'):
super().__init__(dict_factory, **dict(params))
class List(SubFactory):
"""Fill a list with standard declarations."""
FORCE_SEQUENCE = True
def __init__(self, params, list_factory='factory.ListFactory'):
params = {str(i): v for i, v in enumerate(params)}
super().__init__(list_factory, **params)
# Parameters
# ==========
class Skip:
def __bool__(self):
return False
SKIP = Skip()
class Maybe(BaseDeclaration):
def __init__(self, decider, yes_declaration=SKIP, no_declaration=SKIP):
super().__init__()
if enums.get_builder_phase(decider) is None:
# No builder phase => flat value
decider = SelfAttribute(decider, default=None)
self.decider = decider
self.yes = yes_declaration
self.no = no_declaration
phases = {
'yes_declaration': enums.get_builder_phase(yes_declaration),
'no_declaration': enums.get_builder_phase(no_declaration),
}
used_phases = {phase for phase in phases.values() if phase is not None}
if len(used_phases) > 1:
raise TypeError("Inconsistent phases for %r: %r" % (self, phases))
self.FACTORY_BUILDER_PHASE = used_phases.pop() if used_phases else enums.BuilderPhase.ATTRIBUTE_RESOLUTION
def call(self, instance, step, context):
decider_phase = enums.get_builder_phase(self.decider)
if decider_phase == enums.BuilderPhase.ATTRIBUTE_RESOLUTION:
# Note: we work on the *builder stub*, not on the actual instance.
# This gives us access to all Params-level definitions.
choice = self.decider.evaluate(instance=step.stub, step=step, extra=context.extra)
else:
assert decider_phase == enums.BuilderPhase.POST_INSTANTIATION
choice = self.decider.call(instance, step, context)
target = self.yes if choice else self.no
if enums.get_builder_phase(target) == enums.BuilderPhase.POST_INSTANTIATION:
return target.call(
instance=instance,
step=step,
context=context,
)
else:
# Flat value (can't be ATTRIBUTE_RESOLUTION, checked in __init__)
return target
def evaluate(self, instance, step, extra):
choice = self.decider.evaluate(instance=instance, step=step, extra={})
target = self.yes if choice else self.no
if isinstance(target, BaseDeclaration):
return target.evaluate(
instance=instance,
step=step,
extra=extra,
)
else:
# Flat value (can't be POST_INSTANTIATION, checked in __init__)
return target
def __repr__(self):
return 'Maybe(%r, yes=%r, no=%r)' % (self.decider, self.yes, self.no)
class Parameter(utils.OrderedBase):
"""A complex parameter, to be used in a Factory.Params section.
Must implement:
- A "compute" function, performing the actual declaration override
- Optionally, a get_revdeps() function (to compute other parameters it may alter)
"""
def as_declarations(self, field_name, declarations):
"""Compute the overrides for this parameter.
Args:
- field_name (str): the field this parameter is installed at
- declarations (dict): the global factory declarations
Returns:
dict: the declarations to override
"""
raise NotImplementedError()
def get_revdeps(self, parameters):
"""Retrieve the list of other parameters modified by this one."""
return []
class SimpleParameter(Parameter):
def __init__(self, value):
super().__init__()
self.value = value
def as_declarations(self, field_name, declarations):
return {
field_name: self.value,
}
@classmethod
def wrap(cls, value):
if not isinstance(value, Parameter):
return cls(value)
value.touch_creation_counter()
return value
class Trait(Parameter):
"""The simplest complex parameter, it enables a bunch of new declarations based on a boolean flag."""
def __init__(self, **overrides):
super().__init__()
self.overrides = overrides
def as_declarations(self, field_name, declarations):
overrides = {}
for maybe_field, new_value in self.overrides.items():
overrides[maybe_field] = Maybe(
decider=SelfAttribute(
'%s.%s' % (
'.' * maybe_field.count(enums.SPLITTER),
field_name,
),
default=False,
),
yes_declaration=new_value,
no_declaration=declarations.get(maybe_field, SKIP),
)
return overrides
def get_revdeps(self, parameters):
"""This might alter fields it's injecting."""
return [param for param in parameters if param in self.overrides]
def __repr__(self):
return '%s(%s)' % (
self.__class__.__name__,
', '.join('%s=%r' % t for t in self.overrides.items())
)
# Post-generation
# ===============
class PostGenerationDeclaration(BaseDeclaration):
"""Declarations to be called once the model object has been generated."""
FACTORY_BUILDER_PHASE = enums.BuilderPhase.POST_INSTANTIATION
def call(self, instance, step, context): # pragma: no cover
"""Call this hook; no return value is expected.
Args:
obj (object): the newly generated object
create (bool): whether the object was 'built' or 'created'
context: a builder.PostGenerationContext containing values
extracted from the containing factory's declaration
"""
raise NotImplementedError()
class PostGeneration(PostGenerationDeclaration):
"""Calls a given function once the object has been generated."""
def __init__(self, function):
super().__init__()
self.function = function
def call(self, instance, step, context):
logger.debug(
"PostGeneration: Calling %s.%s(%s)",
self.function.__module__,
self.function.__name__,
utils.log_pprint(
(instance, step),
context._asdict(),
),
)
create = step.builder.strategy == enums.CREATE_STRATEGY
return self.function(
instance, create, context.value, **context.extra)
class RelatedFactory(PostGenerationDeclaration):
"""Calls a factory once the object has been generated.
Attributes:
factory (Factory): the factory to call
defaults (dict): extra declarations for calling the related factory
name (str): the name to use to refer to the generated object when
calling the related factory
"""
UNROLL_CONTEXT_BEFORE_EVALUATION = False
def __init__(self, factory, factory_related_name='', **defaults):
super().__init__()
self.name = factory_related_name
self.defaults = defaults
self.factory_wrapper = _FactoryWrapper(factory)
def get_factory(self):
"""Retrieve the wrapped factory.Factory subclass."""
return self.factory_wrapper.get()
def call(self, instance, step, context):
factory = self.get_factory()
if context.value_provided:
# The user passed in a custom value
logger.debug(
"RelatedFactory: Using provided %r instead of generating %s.%s.",
context.value,
factory.__module__, factory.__name__,
)
return context.value
passed_kwargs = dict(self.defaults)
passed_kwargs.update(context.extra)
if self.name:
passed_kwargs[self.name] = instance
logger.debug(
"RelatedFactory: Generating %s.%s(%s)",
factory.__module__,
factory.__name__,
utils.log_pprint((step,), passed_kwargs),
)
return step.recurse(factory, passed_kwargs)
class RelatedFactoryList(RelatedFactory):
"""Calls a factory 'size' times once the object has been generated.
Attributes:
factory (Factory): the factory to call "size-times"
defaults (dict): extra declarations for calling the related factory
factory_related_name (str): the name to use to refer to the generated
object when calling the related factory
size (int|lambda): the number of times 'factory' is called, ultimately
returning a list of 'factory' objects w/ size 'size'.
"""
def __init__(self, factory, factory_related_name='', size=2, **defaults):
self.size = size
super().__init__(factory, factory_related_name, **defaults)
def call(self, instance, step, context):
return [super(RelatedFactoryList, self).call(instance, step, context)
for i in range(self.size if isinstance(self.size, int)
else self.size())]
class NotProvided:
pass
class PostGenerationMethodCall(PostGenerationDeclaration):
"""Calls a method of the generated object.
Attributes:
method_name (str): the method to call
method_args (list): arguments to pass to the method
method_kwargs (dict): keyword arguments to pass to the method
Example:
class UserFactory(factory.Factory):
...
password = factory.PostGenerationMethodCall('set_pass', password='')
"""
def __init__(self, method_name, *args, **kwargs):
super().__init__()
if len(args) > 1:
raise errors.InvalidDeclarationError(
"A PostGenerationMethodCall can only handle 1 positional argument; "
"please provide other parameters through keyword arguments."
)
self.method_name = method_name
self.method_arg = args[0] if args else NotProvided
self.method_kwargs = kwargs
def call(self, instance, step, context):
if not context.value_provided:
if self.method_arg is NotProvided:
args = ()
else:
args = (self.method_arg,)
else:
args = (context.value,)
kwargs = dict(self.method_kwargs)
kwargs.update(context.extra)
method = getattr(instance, self.method_name)
logger.debug(
"PostGenerationMethodCall: Calling %r.%s(%s)",
instance,
self.method_name,
utils.log_pprint(args, kwargs),
)
return method(*args, **kwargs)