from django.contrib.contenttypes.generic import GenericRelation
from django.contrib.contenttypes.models import ContentType
from django.db import models
from django.db.models.fields.related import ManyToManyRel, RelatedField, add_lazy_relation
from django.db.models.related import RelatedObject
from django.utils.text import capfirst
from django.utils.translation import ugettext_lazy as _

from taggit.forms import TagField
from taggit.models import TaggedItem, GenericTaggedItemBase
from taggit.utils import require_instance_manager


try:
    all
except NameError:
    # 2.4 compat
    try:
        from django.utils.itercompat import all
    except ImportError:
        # 1.1.X compat
        def all(iterable):
            """Return True when every item of ``iterable`` is truthy."""
            for item in iterable:
                if not item:
                    return False
            return True


class TaggableRel(ManyToManyRel):
    """Relation descriptor used by ``TaggableManager``.

    The attributes are assigned directly; ``ManyToManyRel.__init__`` is
    deliberately not called.
    """

    def __init__(self):
        self.related_name = None
        self.limit_choices_to = {}
        self.symmetrical = True
        self.multiple = True
        self.through = None


class TaggableManager(RelatedField):
    """Field-like descriptor that exposes a tag manager on a model.

    Accessing the attribute returns a ``_TaggableManager`` bound to the
    owning model and (when accessed through an instance) to that instance.
    """

    def __init__(self, verbose_name=_("Tags"),
                 help_text=_("A comma-separated list of tags."),
                 through=None, blank=False):
        """Store the field options.

        ``through`` is the intermediary model (or its name as a string);
        it falls back to ``TaggedItem`` when falsy.  The base field
        ``__init__`` is not called, so the attributes below are set by hand.
        """
        self.through = through or TaggedItem
        self.rel = TaggableRel()
        self.verbose_name = verbose_name
        self.help_text = help_text
        self.blank = blank
        self.editable = True
        self.unique = False
        self.creates_table = False
        self.db_column = None
        self.choices = None
        self.serialize = False
        self.null = True
        # Take the next slot in the shared field creation order.
        self.creation_counter = models.Field.creation_counter
        models.Field.creation_counter += 1

    def __get__(self, instance, model):
        """Return a ``_TaggableManager`` for ``model`` / ``instance``.

        Raises ``ValueError`` when accessed on an instance whose primary key
        is still ``None``.
        """
        if instance is not None and instance.pk is None:
            raise ValueError("%s objects need to have a primary key value "
                "before you can access their tags." % model.__name__)
        manager = _TaggableManager(
            through=self.through, model=model, instance=instance
        )
        return manager

    def contribute_to_class(self, cls, name):
        """Attach this field to ``cls`` under ``name``.

        For non-abstract models the through model is then set up, lazily if
        it was given as a string.
        """
        self.name = self.column = name
        self.model = cls
        cls._meta.add_field(self)
        setattr(cls, name, self)
        if not cls._meta.abstract:
            if isinstance(self.through, basestring):
                def resolve_related_class(field, model, cls):
                    # Swap the string reference for the resolved model class.
                    self.through = model
                    self.post_through_setup(cls)
                add_lazy_relation(
                    cls, self, self.through, resolve_related_class
                )
            else:
                self.post_through_setup(cls)

    def post_through_setup(self, cls):
        """Finish setup once ``self.through`` is a real model class.

        Records whether the through model is generic (``use_gfk``), points
        ``self.rel.to`` at the model targeted by the through model's ``tag``
        field, and for generic through models adds a ``tagged_items``
        ``GenericRelation`` to ``cls``.
        """
        self.use_gfk = (
            self.through is None or issubclass(self.through, GenericTaggedItemBase)
        )
        self.rel.to = self.through._meta.get_field("tag").rel.to
        if self.use_gfk:
            tagged_items = GenericRelation(self.through)
            tagged_items.contribute_to_class(cls, "tagged_items")

    def save_form_data(self, instance, value):
        """Replace the tags of ``instance`` with the items of ``value``."""
        getattr(instance, self.name).set(*value)

    def formfield(self, form_class=TagField, **kwargs):
        """Build the form field; ``kwargs`` override the computed defaults."""
        defaults = {
            "label": capfirst(self.verbose_name),
            "help_text": self.help_text,
            "required": not self.blank
        }
        defaults.update(kwargs)
        return form_class(**defaults)

    def value_from_object(self, instance):
        """Return the through-model rows for ``instance``.

        An empty queryset is returned when the instance has no primary key.
        """
        if instance.pk:
            return self.through.objects.filter(**self.through.lookup_kwargs(instance))
        return self.through.objects.none()

    def related_query_name(self):
        return self.model._meta.module_name

    def m2m_reverse_name(self):
        return self.through._meta.get_field_by_name("tag")[0].column

    def m2m_target_field_name(self):
        return self.model._meta.pk.name

    def m2m_reverse_target_field_name(self):
        return self.rel.to._meta.pk.name

    def m2m_column_name(self):
        if self.use_gfk:
            return self.through._meta.virtual_fields[0].fk_field
        return self.through._meta.get_field('content_object').column

    def db_type(self, connection=None):
        # This field has no column of its own.
        return None

    def m2m_db_table(self):
        return self.through._meta.db_table

    def extra_filters(self, pieces, pos, negate):
        """Return extra ``(lookup, value)`` filters restricting content type.

        Nothing is added for negated lookups or non-generic through models.
        NOTE(review): ``pieces`` / ``pos`` are presumably the split lookup
        path and the position of this field within it -- confirm against the
        Django query code that calls this hook.
        """
        if negate or not self.use_gfk:
            return []
        prefix = "__".join(["tagged_items"] + pieces[:pos-2])
        # A real list (not ``map``) because it is measured and indexed below.
        cts = [
            ContentType.objects.get_for_model(model)
            for model in _get_subclasses(self.model)
        ]
        if len(cts) == 1:
            return [("%s__content_type" % prefix, cts[0])]
        return [("%s__content_type__in" % prefix, cts)]

    def bulk_related_objects(self, new_objs, using):
        return []


class _TaggableManager(models.Manager):
    """Manager returned by ``TaggableManager.__get__``.

    Holds the through model, the owning model and (optionally) the instance
    whose tags are being managed.
    NOTE(review): methods wrapped in ``require_instance_manager`` presumably
    refuse to run when ``self.instance`` is ``None`` -- see ``taggit.utils``.
    """

    def __init__(self, through, model, instance):
        self.through = through
        self.model = model
        self.instance = instance

    def get_query_set(self):
        return self.through.tags_for(self.model, self.instance)

    def _lookup_kwargs(self):
        """Return the through-model filter kwargs identifying the instance."""
        return self.through.lookup_kwargs(self.instance)

    @require_instance_manager
    def add(self, *tags):
        """Attach ``tags`` to the instance.

        Each item is either a tag-model instance or a tag name; names with
        no existing tag are created first.
        """
        str_tags = set([
            t
            for t in tags
            if not isinstance(t, self.through.tag_model())
        ])
        tag_objs = set(tags) - str_tags
        # If str_tags has 0 elements Django actually optimizes that to not do a
        # query.  Malcolm is very smart.
        existing = self.through.tag_model().objects.filter(
            name__in=str_tags
        )
        tag_objs.update(existing)

        for new_tag in str_tags - set(t.name for t in existing):
            tag_objs.add(self.through.tag_model().objects.create(name=new_tag))

        for tag in tag_objs:
            self.through.objects.get_or_create(tag=tag, **self._lookup_kwargs())

    @require_instance_manager
    def set(self, *tags):
        """Replace all of the instance's tags with ``tags``."""
        self.clear()
        self.add(*tags)

    @require_instance_manager
    def remove(self, *tags):
        """Detach the tags whose names are in ``tags`` from the instance."""
        self.through.objects.filter(**self._lookup_kwargs()).filter(
            tag__name__in=tags).delete()

    @require_instance_manager
    def clear(self):
        """Detach every tag from the instance."""
        self.through.objects.filter(**self._lookup_kwargs()).delete()

    def most_common(self):
        """Return the tags annotated with ``num_times``, most used first."""
        return self.get_query_set().annotate(
            num_times=models.Count(self.through.tag_relname())
        ).order_by('-num_times')

    @require_instance_manager
    def similar_objects(self):
        """Return objects sharing tags with the instance, best match first.

        Each returned object gets a ``similar_tags`` attribute holding the
        number of shared tags.  The instance itself is excluded.
        """
        lookup_kwargs = self._lookup_kwargs()
        lookup_keys = sorted(lookup_kwargs)
        qs = self.through.objects.values(*lookup_kwargs.keys())
        qs = qs.annotate(n=models.Count('pk'))
        qs = qs.exclude(**lookup_kwargs)
        qs = qs.filter(tag__in=self.all())
        qs = qs.order_by('-n')

        # TODO: This all feels like a bit of a hack.
        items = {}
        if len(lookup_keys) == 1:
            # Can we do this without a second query by using a select_related()
            # somehow?
            key = lookup_keys[0]
            f = self.through._meta.get_field_by_name(key)[0]
            # Read the rows with the same key they were selected by, instead
            # of assuming the field is always called "content_object".
            objs = f.rel.to._default_manager.filter(**{
                "%s__in" % f.rel.field_name: [r[key] for r in qs]
            })
            for obj in objs:
                items[(getattr(obj, f.rel.field_name),)] = obj
        else:
            # Group object ids per content type so each type costs one query.
            preload = {}
            for result in qs:
                preload.setdefault(result['content_type'], set())
                preload[result["content_type"]].add(result["object_id"])

            for ct, obj_ids in preload.iteritems():
                ct = ContentType.objects.get_for_id(ct)
                for obj in ct.model_class()._default_manager.filter(pk__in=obj_ids):
                    items[(ct.pk, obj.pk)] = obj

        results = []
        for result in qs:
            # Keys are built in sorted order, matching how ``items`` is keyed.
            obj = items[
                tuple(result[k] for k in lookup_keys)
            ]
            obj.similar_tags = result["n"]
            results.append(obj)
        return results


def _get_subclasses(model):
    """Return ``model`` followed by its subclasses found via parent links.

    Walks every field name of the model, follows related objects whose
    originating field's ``rel`` has a truthy ``parent_link``, and recurses
    into the model on the other side.
    """
    subclasses = [model]
    for f in model._meta.get_all_field_names():
        field = model._meta.get_field_by_name(f)[0]
        if (isinstance(field, RelatedObject) and
                getattr(field.field.rel, "parent_link", None)):
            subclasses.extend(_get_subclasses(field.model))
    return subclasses