OSDN Git Service

implement select, select_1st partialy
authorHiromichi MATSUSHIMA <hirom@office-sv.osdn.jp>
Wed, 22 Jun 2011 10:33:57 +0000 (19:33 +0900)
committerHiromichi MATSUSHIMA <hirom@office-sv.osdn.jp>
Wed, 22 Jun 2011 10:33:57 +0000 (19:33 +0900)
htmltree.py

index d717468..4e50bb3 100644 (file)
@@ -97,6 +97,7 @@ class HTMLElement(list):
         strs = [f(x,self.attrs[x]) for x in self.attrs]
         return " ".join(strs)
 
+    # basic acquision functions
     def get_attribute(self, attr, default=None):
         """returns given attribute's value."""
         return self.attrs.get(attr, default)
@@ -109,6 +110,12 @@ class HTMLElement(list):
         """returns content in the tag."""
         return self._text
 
+    def inner_html(self):
+        "returns inner html"
+        rn = HTMLRenderer()
+        return rn.render_inner(self)
+
+    # navigation functions
     def parent(self):
         """returns tag's parent element."""
         return self._parent
@@ -121,6 +128,7 @@ class HTMLElement(list):
         """returns tag's previous element."""
         return self._prev_elem
 
+    # basic query functions
     def get_elements_by_name(self, name):
         buf = []
         self._r_get_elements_by_name(name, buf)
@@ -142,10 +150,43 @@ class HTMLElement(list):
         #raise HTMLElementError("Element not found")
         return None
 
+    def get_elements_by_class(self, cls):
+        buf = []
+        self._r_get_elements_by_class(cls, buf)
+        return buf
+
+    def _r_get_elements_by_class(self, cls, buf):
+        if self.get_attribute("class") == cls:
+            buf.append(self)
+        for i in self:
+            i._r_get_elements_by_class(cls, buf)
+
+    # query functions
+    # TODO: this function is under implementing...
     def select(self, expr):
         terms = expr.strip().split()
         if len(terms) == 0:
             return []
+        results = self
+        # at first, select #id
+        ids = [x[1:] for x in terms if x[0] == "#"]
+        if len(ids) != 0:
+            results = [results.get_element_by_id(ids[-1]),]
+        # next, select .class
+        cs = [x[1:] for x in terms if x[0] == "."]
+        t = []
+        if len(cs) != 0:
+            for c in cs:
+                t.extend(results.get_elements_by_class(c))
+            results = t
+        return results
+
+    def select_1st(self, expr):
+        r = self.select(expr)
+        if len(r) == 0:
+            return None
+        else:
+            return r[0]
 
     def select_by_name2(self, term1, term2):
         tbl = self.get_elements_by_name(term1)
@@ -155,6 +196,7 @@ class HTMLElement(list):
             buf.extend(st)
         return buf
 
+    # is_* functions
     def is_text(self):
         return self.type == HTMLElement.TEXT
 
@@ -172,6 +214,7 @@ class HTMLElement(list):
             p = p.parent()
         return False
 
+    # mmmh....
     def trace_back(self, tag):
         """ regexp string => list"""
         p = self.parent()
@@ -183,10 +226,6 @@ class HTMLElement(list):
             p = p.parent()
         return result
 
-    def inner_html(self):
-        "returns inner html"
-        rn = HTMLRenderer()
-        return rn.render_inner(self)
 
 class HTMLTreeError(Exception):
     def __init__(self, msg, lineno, offset):