OSDN Git Service

add 'remove_element' func
[htmltree-py/htmltree.git] / htmltree.py
1 # htmltree.py by hylom
2 # -*- coding: utf-8 -*-
3
4 """htmltree.py - HTML Element-Tree Builder
5 by hylom <hylomm@@single_at_mark@@gmail.com>
6 """
7
8 import HTMLParser
9 import re
10
11 class HTMLElementError(Exception):
12     def __init__(self, msg, elem):
13         self.msg = msg
14         self.elem = elem
15
16     def __repr__(self):
17         str = "HTML Element Error: %s in %s" % (self.msg, self.elem)
18         return str
19
20 class Renderer(object):
21     """HTMLElement Render base class."""
22     def attrs2str(self, elem):
23         strs = []
24         for attr in elem.attrs:
25             if elem.attrs[attr] == None:
26                 strs.append(attr)
27             elif "'" in elem.attrs[attr]:
28                 strs.append('%s="%s"' % (attr, elem.attrs[attr]))
29             else:
30                 strs.append("%s='%s'" % (attr, elem.attrs[attr]))
31         strs.insert(0, "")
32         return " ".join(strs)
33
34 class HTMLRenderer(Renderer):
35     """Render HTMLElement as HTML."""
36     # TODO: check tags not need to close more strict...
37     UNCLOSABLE_TAGS = ["br", "link", "meta", "img"]
38
39     def render_inner(self, elem):
40         texts = []
41         for child in elem:
42             self._recursive(child, texts)
43         return "".join(texts)
44
45     def render(self, elem):
46         texts = []
47         self._recursive(elem, texts)
48         return "".join(texts)
49
50     def _recursive(self, elem, texts):
51         if elem.is_tag():
52             texts.append("<" + elem.name + self.attrs2str(elem) + ">")
53             for child in elem:
54                 self._recursive(child, texts)
55             if not elem.name in self.UNCLOSABLE_TAGS:
56                 texts.append("</" + elem.name + ">")
57         elif elem.is_text():
58             if elem.text():
59                 texts.append(elem.text())
60         elif elem.is_root():
61             for child in elem:
62                 self._recursive(child, texts)
63         
64
65 class HTMLElement(list):
66     """HTML element object to use as tree nodes."""
67     ROOT = 0
68     TAG = 100
69     TEXT = 200
70
71     def __init__(self, type, name="", attrs={}):
72         """
73         create HTMLElement object.
74
75         Arguments:
76         type -- element type. HTMLElement.(ROOT|TAG|TEXT)
77         name -- element name (default: "")
78         attrs -- dict of attributes (default:{})
79
80         Example:
81         attr = dict(href="http://example.com/", target="_blank")
82         e = HTMLElement(HTMLElement.TAG, "a", attr)
83         # 'e' means <a href="http://example.com/" target="_blank">
84         """
85
86         self.type = type
87         self.name = name
88         self.attrs = dict(attrs)
89         self._text = ""
90         self._parent = None
91         self._next_elem = None
92         self._prev_elem = None
93
94     def __repr__(self):
95         if self.type == HTMLElement.TAG:
96             return "<TAG:%s %s>" % (self.name, self._attrs2str())
97         elif self.type == HTMLElement.TEXT:
98             return "<TEXT:'%s'>" % self._text
99         else:
100             return "<UNKNOWN>"
101
102     def _attrs2str(self):
103         str = []
104         f = lambda x,y: x if y == None else "%s='%s'" % (x,y)
105
106         strs = [f(x,self.attrs[x]) for x in self.attrs]
107         return " ".join(strs)
108
109     # basic acquision functions
110     def get_attribute(self, attr, default=None):
111         """returns given attribute's value."""
112         return self.attrs.get(attr, default)
113
114     def has_attribute(self, attr):
115         """returns True if element has "attr" attribute."""
116         return attr in self.attrs
117
118     def text(self):
119         """returns content in the tag."""
120         return self._text
121
122     def inner_html(self):
123         "returns inner html"
124         rn = HTMLRenderer()
125         return rn.render_inner(self)
126
127     # navigation functions
128     def parent(self):
129         """returns tag's parent element."""
130         return self._parent
131
132     def next(self):
133         """returns tag's next element."""
134         return self._next_elem
135
136     def prev(self):
137         """returns tag's previous element."""
138         return self._prev_elem
139
140     # basic query functions
141     def get_elements_by_name(self, name):
142         buf = []
143         self._r_get_elements_by_name(name, buf)
144         return buf
145
146     def _r_get_elements_by_name(self, name, buf):
147         if self.name == name:
148             buf.append(self)
149         for i in self:
150             i._r_get_elements_by_name(name, buf)
151
152     def get_element_by_id(self, id):
153         if "id" in self.attrs and self.attrs["id"] == id:
154             return self
155         for i in self:
156             e = i.get_element_by_id(id)
157             if e != None:
158                 return e
159         #raise HTMLElementError("Element not found")
160         return None
161
162     def get_elements_by_class(self, cls):
163         buf = []
164         self._r_get_elements_by_class(cls, buf)
165         return buf
166
167     def _r_get_elements_by_class(self, cls, buf):
168         if self.get_attribute("class") == cls:
169             buf.append(self)
170         for i in self:
171             i._r_get_elements_by_class(cls, buf)
172
173     # manipulation functions
174     def append_tag(self, tag, attrs):
175         elem = HTMLElement(HTMLElement.TAG, tag, attrs)
176         self.append(elem)
177
178     def remove_element(self, elem):
179         parent = elem.parent()
180         parent.remove(elem)
181
182     # query functions
183     # TODO: this function is under implementing...
184     def select(self, expr):
185         terms = expr.strip().split()
186         if len(terms) == 0:
187             return []
188         results = self
189         for pat in terms:
190             t = []
191             for elem in results:
192                 t.extend(self._select_pattern(pat, elem))
193             results = t
194         return results
195
196     def _select_pattern(self, pat, elem):
197         results = []
198         if pat[0] == "#":
199             results = [elem.get_element_by_id(pat[1:]),]
200         elif pat[0] == ".":
201             results = elem.get_elements_by_class(pat[1:])
202         return [x for x in results if x]
203
204     def select_1st(self, expr):
205         r = self.select(expr)
206         if len(r) == 0:
207             return None
208         else:
209             return r[0]
210
211     def select_by_name2(self, term1, term2):
212         tbl = self.get_elements_by_name(term1)
213         buf = []
214         for elem in tbl:
215             st = elem.get_elements_by_name(term2)
216             buf.extend(st)
217         return buf
218
219     # is_* functions
220     def is_text(self):
221         return self.type == HTMLElement.TEXT
222
223     def is_tag(self):
224         return self.type == HTMLElement.TAG
225
226     def is_root(self):
227         return self.type == HTMLElement.ROOT
228
229     def is_descendant(self, tagname):
230         p = self.parent()
231         while p != None:
232             if p.name == tagname:
233                 return p
234             p = p.parent()
235         return False
236
237     # mmmh....
238     def trace_back(self, tag):
239         """ regexp string => list"""
240         p = self.parent()
241         rex = re.compile(tag)
242         result = []
243         while p != None:
244             if rex.search(p.name):
245                 result.append(p.name)
246             p = p.parent()
247         return result
248
249
250 class HTMLTreeError(Exception):
251     def __init__(self, msg, lineno, offset):
252         self.msg = msg
253         self.lineno = lineno
254         self.offset = offset
255
256     def __repr__(self):
257         str = "HTML Parse Error: %s , line: %d, char: %d" % (self.msg, self.lineno, self.offset)
258         return str
259     
260
261 class HTMLTree(HTMLParser.HTMLParser):
262     "HTML Tree Builder"
263     USE_VALIDATE = 0x0001
264
265     IGNORE_BLANK = 0x0010
266     TRUNC_BLANK  = 0x0020
267     JOIN_TEXT    = 0x0040
268
269     TRUNC_BR = 0x0100
270     # TODO: check tags not need to close more strict...
271     UNCLOSABLE_TAGS = ["br", "link", "meta", "img", "input"]
272
273     def __init__(self):
274         "Constructor"
275         HTMLParser.HTMLParser.__init__(self)
276
277     def parse(self, data, charset=None, option=0):
278         """
279         Parse given HTML.
280
281         Arguments:
282         data -- HTML to parse
283         charset -- charset of HTML (default: None)
284         option -- option (default: 0, meaning none)
285         
286         """
287
288         self.charset = charset
289         self._htmlroot = HTMLElement(HTMLElement.ROOT)
290         self._cursor = self._htmlroot
291         self._option = option
292         try:
293             self.feed(data)
294         except HTMLParser.HTMLParseError, e:
295             raise HTMLTreeError("HTML parse error: " + e.msg,
296                                 e.lineno, e.offset)
297
298         # if charset is not given, detect charset
299         if self.charset == None:
300             r = self.root()
301             metas = r.get_elements_by_name("meta")
302             for meta in metas:
303                 if meta.attrs.get("http-equiv", None) == "Content-Type":
304                     ctype = meta.attrs.get("content", "")
305                     m = re.search(r"charset=([^;]+)", ctype)
306                     if m:
307                         self.charset = m.group(1)
308                     else:
309                         self.charset = None
310                         
311             if self.charset:
312                 self._htmlroot = HTMLElement(HTMLElement.ROOT)
313                 self._cursor = self._htmlroot
314                 self.feed(data)
315
316         self._finalize()
317
318     def _finalize(self):
319         r = self.root()
320         self._r_finalize(r)
321
322     def _r_finalize(self, elem):
323         if elem.is_text():
324             return
325         
326         l = len(elem)
327         if l > 1:
328             elem[0]._next_elem = elem[1]
329         for i in range(1, l-1):
330             elem[i]._prev_elem = elem[i-1]
331             elem[i]._next_elem = elem[i+1]
332         if l > 1:
333             elem[l-1]._prev_elem = elem[l-2]
334
335         for sub_elem in elem:
336             self._r_finalize(sub_elem)
337
338     def validate(self):
339         r = self.root()
340         self._r_validate(self, e)
341
342     # Handlers
343     def handle_starttag(self, tag, attrs):
344         # some tags treat as start-end tag.
345         if tag in self.UNCLOSABLE_TAGS:
346             return self.handle_startendtag(tag, attrs)
347             
348         elem = HTMLElement(HTMLElement.TAG, tag, attrs)
349
350         if self._option & HTMLTree.USE_VALIDATE > 0:
351             # try validation (experimental)
352             if tag == "li" and self._cursor.name == "li":
353                 self.handle_endtag("li")
354             # end of validation
355
356         elem._parent = self._cursor
357         self._cursor.append(elem)
358         self._cursor = elem
359
360     def handle_endtag(self, tag):
361         # some tags treat as start-end tag.
362         if tag in self.UNCLOSABLE_TAGS:
363             return
364
365         self._cursor = self._cursor.parent()
366
367     def handle_startendtag(self, tag, attrs):
368         elem = HTMLElement(HTMLElement.TAG, tag, attrs)
369         elem._parent = self._cursor
370         self._cursor.append(elem)
371
372     def handle_data(self, data):
373         if self._option & HTMLTree.IGNORE_BLANK > 0:
374             if re.search(r"^\s*$", data):
375                 data = ""
376
377         elem = HTMLElement(HTMLElement.TEXT)
378         elem._parent = self._cursor
379
380         # text encode check and convert.
381         # if charset is given, convert text to unicode type.
382         if self.charset:
383             try:
384                 elem._text = unicode(data, self.charset)
385             except TypeError:
386                 # self.charset is utf-8.
387                 elem._text = data
388         else:
389             # treat as unicode input
390             elem._text = data
391         self._cursor.append(elem)
392
393     # Accessor
394     def root(self):
395         return self._htmlroot