OSDN Git Service

append TEXTRenderer and HTMLElement.inner_text()
[htmltree-py/htmltree.git] / htmltree.py
index 2403c0b..fea5a7b 100644 (file)
@@ -8,12 +8,111 @@ by hylom <hylomm@@single_at_mark@@gmail.com>
 import HTMLParser
 import re
 
+class HTMLElementError(Exception):
+    def __init__(self, msg, elem):
+        self.msg = msg
+        self.elem = elem
+
+    def __repr__(self):
+        str = "HTML Element Error: %s in %s" % (self.msg, self.elem)
+        return str
+
+class Renderer(object):
+    """HTMLElement Render base class."""
+    def attrs2str(self, elem):
+        strs = []
+        for attr in elem.attrs:
+            if elem.attrs[attr] == None:
+                strs.append(attr)
+            elif "'" in elem.attrs[attr]:
+                strs.append('%s="%s"' % (attr, elem.attrs[attr]))
+            else:
+                strs.append("%s='%s'" % (attr, elem.attrs[attr]))
+        strs.insert(0, "")
+        return " ".join(strs)
+
+class HTMLRenderer(Renderer):
+    """Render HTMLElement as HTML."""
+    # TODO: check tags not need to close more strict...
+    UNCLOSABLE_TAGS = ["br", "link", "meta", "img"]
+
+    def render_inner(self, elem):
+        texts = []
+        for child in elem:
+            self._recursive(child, texts)
+        return "".join(texts)
+
+    def render(self, elem):
+        texts = []
+        self._recursive(elem, texts)
+        return "".join(texts)
+
+    def _recursive(self, elem, texts):
+        if elem.is_tag():
+            texts.append("<" + elem.name + self.attrs2str(elem) + ">")
+            for child in elem:
+                self._recursive(child, texts)
+            if not elem.name in self.UNCLOSABLE_TAGS:
+                texts.append("</" + elem.name + ">")
+        elif elem.is_text():
+            if elem.text():
+                texts.append(elem.text())
+        elif elem.is_root():
+            for child in elem:
+                self._recursive(child, texts)
+        elif elem.is_decl():
+            texts.append("<!" + elem.name + ">")
+
+
+class TEXTRenderer(Renderer):
+    """Render HTMLElement as TEXT."""
+    # TODO: check tags not need to close more strict...
+    UNCLOSABLE_TAGS = ["br", "link", "meta", "img"]
+
+    def render_inner(self, elem):
+        texts = []
+        for child in elem:
+            self._recursive(child, texts)
+        return "".join(texts)
+
+    def render(self, elem):
+        texts = []
+        self._recursive(elem, texts)
+        return "".join(texts)
+
+    def _recursive(self, elem, texts):
+        if elem.is_tag():
+            for child in elem:
+                self._recursive(child, texts)
+        elif elem.is_text():
+            if elem.text():
+                texts.append(elem.text())
+        elif elem.is_root():
+            for child in elem:
+                self._recursive(child, texts)
+
 class HTMLElement(list):
-    "HTML element object"
+    """HTML element object to use as tree nodes."""
     ROOT = 0
     TAG = 100
     TEXT = 200
+    DECL = 300
+
     def __init__(self, type, name="", attrs={}):
+        """
+        create HTMLElement object.
+
+        Arguments:
+        type -- element type. HTMLElement.(ROOT|TAG|TEXT)
+        name -- element name (default: "")
+        attrs -- dict of attributes (default:{})
+
+        Example:
+        attr = dict(href="http://example.com/", target="_blank")
+        e = HTMLElement(HTMLElement.TAG, "a", attr)
+        # 'e' means <a href="http://example.com/" target="_blank">
+        """
+
         self.type = type
         self.name = name
         self.attrs = dict(attrs)
@@ -25,10 +124,15 @@ class HTMLElement(list):
     def __repr__(self):
         if self.type == HTMLElement.TAG:
             return "<TAG:%s %s>" % (self.name, self._attrs2str())
+        elif self.type == HTMLElement.DECL:
+            return "<DECL:'%s'>" % self.name
         elif self.type == HTMLElement.TEXT:
             return "<TEXT:'%s'>" % self._text
         else:
-            return None
+            return "<UNKNOWN>"
+
+    def __eq__(self, other):
+        return id(self) == id(other)
 
     def _attrs2str(self):
         str = []
@@ -37,18 +141,47 @@ class HTMLElement(list):
         strs = [f(x,self.attrs[x]) for x in self.attrs]
         return " ".join(strs)
 
+    # basic acquision functions
+    def get_attribute(self, attr, default=None):
+        """returns given attribute's value."""
+        return self.attrs.get(attr, default)
+
+    def attr(self, attr, default=None):
+        """returns given attribute's value."""
+        return self.attrs.get(attr, default)
+
+    def has_attribute(self, attr):
+        """returns True if element has "attr" attribute."""
+        return attr in self.attrs
+
     def text(self):
+        """returns content in the tag."""
         return self._text
 
+    def inner_html(self):
+        "returns inner html"
+        rn = HTMLRenderer()
+        return rn.render_inner(self)
+
+    def inner_text(self):
+        "returns inner text"
+        rn = TEXTRenderer()
+        return rn.render_inner(self)
+
+    # navigation functions
     def parent(self):
+        """returns tag's parent element."""
         return self._parent
 
     def next(self):
+        """returns tag's next element."""
         return self._next_elem
 
     def prev(self):
+        """returns tag's previous element."""
         return self._prev_elem
 
+    # basic query functions
     def get_elements_by_name(self, name):
         buf = []
         self._r_get_elements_by_name(name, buf)
@@ -61,18 +194,78 @@ class HTMLElement(list):
             i._r_get_elements_by_name(name, buf)
 
     def get_element_by_id(self, id):
-        if self.attr["id"] == id:
+        if "id" in self.attrs and self.attrs["id"] == id:
             return self
         for i in self:
             e = i.get_element_by_id(id)
             if e != None:
                 return e
+        #raise HTMLElementError("Element not found")
         return None
 
+    def get_elements_by_class(self, cls):
+        buf = []
+        self._r_get_elements_by_class(cls, buf)
+        return buf
+
+    def _r_get_elements_by_class(self, cls, buf):
+        if self.get_attribute("class") == cls:
+            buf.append(self)
+        for i in self:
+            i._r_get_elements_by_class(cls, buf)
+
+    def get_elements(self, name, attrs):
+        elems = self.get_elements_by_name(name)
+        results = []
+        for elem in elems:
+            for name in attrs:
+                if elem.get_attribute(name, "") != attrs[name]:
+                    break
+            else:
+                results.append(elem)
+        return results
+
+    # manipulation functions
+    def append_tag(self, tag, attrs):
+        elem = HTMLElement(HTMLElement.TAG, tag, attrs)
+        self.append(elem)
+
+    def remove_element(self, elem):
+        parent = elem.parent()
+        parent.remove(elem)
+
+    def delete(self):
+        p = self.parent()
+        p.remove(self)
+
+    # query functions
+    # TODO: this function is under implementing...
     def select(self, expr):
         terms = expr.strip().split()
         if len(terms) == 0:
             return []
+        results = self
+        for pat in terms:
+            t = []
+            for elem in results:
+                t.extend(self._select_pattern(pat, elem))
+            results = t
+        return results
+
+    def _select_pattern(self, pat, elem):
+        results = []
+        if pat[0] == "#":
+            results = [elem.get_element_by_id(pat[1:]),]
+        elif pat[0] == ".":
+            results = elem.get_elements_by_class(pat[1:])
+        return [x for x in results if x]
+
+    def select_1st(self, expr):
+        r = self.select(expr)
+        if len(r) == 0:
+            return None
+        else:
+            return r[0]
 
     def select_by_name2(self, term1, term2):
         tbl = self.get_elements_by_name(term1)
@@ -82,6 +275,7 @@ class HTMLElement(list):
             buf.extend(st)
         return buf
 
+    # is_* functions
     def is_text(self):
         return self.type == HTMLElement.TEXT
 
@@ -91,6 +285,9 @@ class HTMLElement(list):
     def is_root(self):
         return self.type == HTMLElement.ROOT
 
+    def is_decl(self):
+        return self.type == HTMLElement.DECL
+
     def is_descendant(self, tagname):
         p = self.parent()
         while p != None:
@@ -99,6 +296,7 @@ class HTMLElement(list):
             p = p.parent()
         return False
 
+    # mmmh....
     def trace_back(self, tag):
         """ regexp string => list"""
         p = self.parent()
@@ -110,6 +308,7 @@ class HTMLElement(list):
             p = p.parent()
         return result
 
+
 class HTMLTreeError(Exception):
     def __init__(self, msg, lineno, offset):
         self.msg = msg
@@ -119,7 +318,14 @@ class HTMLTreeError(Exception):
     def __repr__(self):
         str = "HTML Parse Error: %s , line: %d, char: %d" % (self.msg, self.lineno, self.offset)
         return str
-    
+
+
+def parse(data, charset=None, option=0):
+    "parse HTML and returns HTMLTree object"
+    tree = HTMLTree()
+    tree.parse(data, charset, option)
+    return tree
+
 
 class HTMLTree(HTMLParser.HTMLParser):
     "HTML Tree Builder"
@@ -130,12 +336,24 @@ class HTMLTree(HTMLParser.HTMLParser):
     JOIN_TEXT    = 0x0040
 
     TRUNC_BR = 0x0100
+    # TODO: check tags not need to close more strict...
+    UNCLOSABLE_TAGS = ["br", "link", "meta", "img", "input"]
 
     def __init__(self):
         "Constructor"
         HTMLParser.HTMLParser.__init__(self)
 
     def parse(self, data, charset=None, option=0):
+        """
+        Parse given HTML.
+
+        Arguments:
+        data -- HTML to parse
+        charset -- charset of HTML (default: None)
+        option -- option (default: 0, meaning none)
+        
+        """
+
         self.charset = charset
         self._htmlroot = HTMLElement(HTMLElement.ROOT)
         self._cursor = self._htmlroot
@@ -193,7 +411,7 @@ class HTMLTree(HTMLParser.HTMLParser):
     # Handlers
     def handle_starttag(self, tag, attrs):
         # some tags treat as start-end tag.
-        if tag in ["br",]:
+        if tag in self.UNCLOSABLE_TAGS:
             return self.handle_startendtag(tag, attrs)
             
         elem = HTMLElement(HTMLElement.TAG, tag, attrs)
@@ -210,7 +428,7 @@ class HTMLTree(HTMLParser.HTMLParser):
 
     def handle_endtag(self, tag):
         # some tags treat as start-end tag.
-        if tag in ["br",]:
+        if tag in self.UNCLOSABLE_TAGS:
             return
 
         self._cursor = self._cursor.parent()
@@ -227,12 +445,33 @@ class HTMLTree(HTMLParser.HTMLParser):
 
         elem = HTMLElement(HTMLElement.TEXT)
         elem._parent = self._cursor
+
+        # text encode check and convert.
+        # if charset is given, convert text to unicode type.
         if self.charset:
-            elem._text = unicode(data, self.charset).encode("utf-8")
+            try:
+                elem._text = unicode(data, self.charset)
+            except TypeError:
+                # self.charset is utf-8.
+                elem._text = data
         else:
+            # treat as unicode input
             elem._text = data
         self._cursor.append(elem)
 
+    def handle_entityref(self, name):
+        data = "&" + name + ";"
+        self.handle_data(data)
+
+    def handle_charref(self, ref):
+        data = "&#" + ref + ";"
+        self.handle_data(data)
+
+    def handle_decl(self, decl):
+        elem = HTMLElement(HTMLElement.DECL, decl)
+        elem._parent = self._cursor
+        self._cursor.append(elem)
+
     # Accessor
     def root(self):
         return self._htmlroot