Changes to 3 files · Browse files at 79b760a98784 Showing diff from parent 2782f6edb956 Diff from another changeset...
@@ -378,7 +378,7 @@ path = self._get_shafile_path(sha)
try:
return ShaFile.from_file(path)
- except OSError, e:
+ except (OSError, IOError), e:
if e.errno == errno.ENOENT:
return None
raise
|
|
|
@@ -99,9 +99,10 @@ """Get the object class corresponding to the given type.
:param type: Either a type name string or a numeric type.
- :return: The ShaFile subclass corresponding to the given type.
+ :return: The ShaFile subclass corresponding to the given type, or None if
+ type is not a valid type name/number.
"""
- return _TYPE_MAP[type]
+ return _TYPE_MAP.get(type, None)
def check_hexsha(hex, error_msg):
@@ -124,32 +125,40 @@class ShaFile(object):
"""A git SHA file."""
- @classmethod
- def _parse_legacy_object(cls, map):
- """Parse a legacy object, creating it and setting object._text"""
- text = _decompress(map)
- object = None
- for cls in OBJECT_CLASSES:
- if text.startswith(cls.type_name):
- object = cls()
- text = text[len(cls.type_name):]
- break
- assert object is not None, "%s is not a known object type" % text[:9]
- assert text[0] == ' ', "%s is not a space" % text[0]
- text = text[1:]
- size = 0
- i = 0
- while text[0] >= '0' and text[0] <= '9':
- if i > 0 and size == 0:
- raise AssertionError("Size is not in canonical format")
- size = (size * 10) + int(text[0])
- text = text[1:]
- i += 1
- object._size = size
- assert text[0] == "\0", "Size not followed by null"
- text = text[1:]
- object.set_raw_string(text)
- return object
+ @staticmethod
+ def _parse_legacy_object_header(magic, f):
+ """Parse a legacy object, creating it but not reading the file."""
+ bufsize = 1024
+ decomp = zlib.decompressobj()
+ header = decomp.decompress(magic)
+ start = 0
+ end = -1
+ while end < 0:
+ header += decomp.decompress(f.read(bufsize))
+ end = header.find("\0", start)
+ start = len(header)
+ header = header[:end]
+ type_name, size = header.split(" ", 1)
+ size = int(size) # sanity check
+ obj_class = object_class(type_name)
+ if not obj_class:
+ raise ObjectFormatException("Not a known type: %s" % type_name)
+ obj = obj_class()
+ obj._filename = f.name
+ return obj
+
+ def _parse_legacy_object(self, f):
+ """Parse a legacy object, setting the raw string."""
+ size = os.path.getsize(f.name)
+ map = mmap.mmap(f.fileno(), size, access=mmap.ACCESS_READ)
+ try:
+ text = _decompress(map)
+ finally:
+ map.close()
+ header_end = text.find('\0')
+ if header_end < 0:
+ raise ObjectFormatException("Invalid object header")
+ self.set_raw_string(text[header_end+1:])
def as_legacy_object_chunks(self):
compobj = zlib.compressobj()
@@ -162,9 +171,10 @@ return "".join(self.as_legacy_object_chunks())
def as_raw_chunks(self):
- if self._needs_serialization:
+ if self._needs_parsing:
+ self._ensure_parsed()
+ else:
self._chunked_text = self._serialize()
- self._needs_serialization = False
return self._chunked_text
def as_raw_string(self):
@@ -181,6 +191,9 @@
def _ensure_parsed(self):
if self._needs_parsing:
+ if not self._chunked_text:
+ assert self._filename, "ShaFile needs either text or filename"
+ self._parse_file()
self._deserialize(self._chunked_text)
self._needs_parsing = False
@@ -195,35 +208,55 @@ self._needs_parsing = True
self._needs_serialization = False
+ @staticmethod
+ def _parse_object_header(magic, f):
+ """Parse a new style object, creating it but not reading the file."""
+ num_type = (ord(magic[0]) >> 4) & 7
+ obj_class = object_class(num_type)
+ if not obj_class:
+ raise ObjectFormatError("Not a known type: %d" % num_type)
+ obj = obj_class()
+ obj._filename = f.name
+ return obj
+
+ def _parse_object(self, f):
+ """Parse a new style object, setting self._text."""
+ size = os.path.getsize(f.name)
+ map = mmap.mmap(f.fileno(), size, access=mmap.ACCESS_READ)
+ try:
+ # skip type and size; type must have already been determined, and we
+ # trust zlib to fail if it's otherwise corrupted
+ byte = ord(map[0])
+ used = 1
+ while (byte & 0x80) != 0:
+ byte = ord(map[used])
+ used += 1
+ raw = map[used:]
+ self.set_raw_string(_decompress(raw))
+ finally:
+ map.close()
+
@classmethod
- def _parse_object(cls, map):
- """Parse a new style object , creating it and setting object._text"""
- used = 0
- byte = ord(map[used])
- used += 1
- type_num = (byte >> 4) & 7
- try:
- object = object_class(type_num)()
- except KeyError:
- raise AssertionError("Not a known type: %d" % type_num)
- while (byte & 0x80) != 0:
- byte = ord(map[used])
- used += 1
- raw = map[used:]
- object.set_raw_string(_decompress(raw))
- return object
+ def _is_legacy_object(cls, magic):
+ b0, b1 = map(ord, magic)
+ word = (b0 << 8) + b1
+ return b0 == 0x78 and (word % 31) == 0
@classmethod
- def _parse_file(cls, map):
- word = (ord(map[0]) << 8) + ord(map[1])
- if ord(map[0]) == 0x78 and (word % 31) == 0:
- return cls._parse_legacy_object(map)
+ def _parse_file_header(cls, f):
+ magic = f.read(2)
+ if cls._is_legacy_object(magic):
+ return cls._parse_legacy_object_header(magic, f)
else:
- return cls._parse_object(map)
+ return cls._parse_object_header(magic, f)
def __init__(self):
"""Don't call this directly"""
self._sha = None
+ self._filename = None
+ self._chunked_text = []
+ self._needs_parsing = False
+ self._needs_serialization = True
def _deserialize(self, chunks):
raise NotImplementedError(self._deserialize)
@@ -231,15 +264,29 @@ def _serialize(self):
raise NotImplementedError(self._serialize)
+ def _parse_file(self):
+ f = GitFile(self._filename, 'rb')
+ try:
+ magic = f.read(2)
+ if self._is_legacy_object(magic):
+ self._parse_legacy_object(f)
+ else:
+ self._parse_object(f)
+ finally:
+ f.close()
+
@classmethod
def from_file(cls, filename):
- """Get the contents of a SHA file on disk"""
- size = os.path.getsize(filename)
+ """Get the contents of a SHA file on disk."""
f = GitFile(filename, 'rb')
try:
- map = mmap.mmap(f.fileno(), size, access=mmap.ACCESS_READ)
- shafile = cls._parse_file(map)
- return shafile
+ try:
+ obj = cls._parse_file_header(f)
+ obj._needs_parsing = True
+ obj._needs_serialization = True
+ return obj
+ except (IndexError, ValueError), e:
+ raise ObjectFormatException("invalid object header")
finally:
f.close()
@@ -267,7 +314,7 @@
@classmethod
def from_string(cls, string):
- """Create a blob from a string."""
+ """Create a ShaFile from a string."""
obj = cls()
obj.set_raw_string(string)
return obj
@@ -367,13 +414,23 @@ self.set_raw_string(data)
data = property(_get_data, _set_data,
- "The text contained within the blob object.")
+ "The text contained within the blob object.")
def _get_chunked(self):
+ self._ensure_parsed()
return self._chunked_text
def _set_chunked(self, chunks):
self._chunked_text = chunks
+
+ def _serialize(self):
+ if not self._chunked_text:
+ self._ensure_parsed()
+ self._needs_serialization = False
+ return self._chunked_text
+
+ def _deserialize(self, chunks):
+ return "".join(chunks)
chunked = property(_get_chunked, _set_chunked,
"The text within the blob object, as chunks (not necessarily lines).")
@@ -424,8 +481,6 @@
def __init__(self):
super(Tag, self).__init__()
- self._needs_parsing = False
- self._needs_serialization = True
self._tag_timezone_neg_utc = False
@classmethod
@@ -434,13 +489,6 @@ if not isinstance(tag, cls):
raise NotTagError(filename)
return tag
-
- @classmethod
- def from_string(cls, string):
- """Create a blob from a string."""
- shafile = cls()
- shafile.set_raw_string(string)
- return shafile
def check(self):
"""Check this object for internal consistency.
@@ -600,8 +648,6 @@ def __init__(self):
super(Tree, self).__init__()
self._entries = {}
- self._needs_parsing = False
- self._needs_serialization = True
@classmethod
def from_file(cls, filename):
@@ -668,7 +714,6 @@ # TODO: list comprehension is for efficiency in the common (small) case;
# if memory efficiency in the large case is a concern, use a genexp.
self._entries = dict([(n, (m, s)) for n, m, s in parsed_entries])
- self._needs_parsing = False
def check(self):
"""Check this object for internal consistency.
@@ -746,8 +791,6 @@ super(Commit, self).__init__()
self._parents = []
self._encoding = None
- self._needs_parsing = False
- self._needs_serialization = True
self._extra = {}
self._author_timezone_neg_utc = False
self._commit_timezone_neg_utc = False
|
|
|
@@ -212,11 +212,13 @@
class ShaFileCheckTests(unittest.TestCase):
- def assertCheckFails(self, obj, data):
+ def assertCheckFails(self, cls, data):
+ obj = cls()
obj.set_raw_string(data)
self.assertRaises(ObjectFormatException, obj.check)
- def assertCheckSucceeds(self, obj, data):
+ def assertCheckSucceeds(self, cls, data):
+ obj = cls()
obj.set_raw_string(data)
try:
obj.check()
@@ -343,22 +345,22 @@ self.assertEquals('UTF-8', c.encoding)
def test_check(self):
- self.assertCheckSucceeds(Commit(), self.make_commit_text())
- self.assertCheckSucceeds(Commit(), self.make_commit_text(parents=None))
- self.assertCheckSucceeds(Commit(),
+ self.assertCheckSucceeds(Commit, self.make_commit_text())
+ self.assertCheckSucceeds(Commit, self.make_commit_text(parents=None))
+ self.assertCheckSucceeds(Commit,
self.make_commit_text(encoding='UTF-8'))
- self.assertCheckFails(Commit(), self.make_commit_text(tree='xxx'))
- self.assertCheckFails(Commit(), self.make_commit_text(
+ self.assertCheckFails(Commit, self.make_commit_text(tree='xxx'))
+ self.assertCheckFails(Commit, self.make_commit_text(
parents=[a_sha, 'xxx']))
bad_committer = "some guy without an email address 1174773719 +0000"
- self.assertCheckFails(Commit(),
+ self.assertCheckFails(Commit,
self.make_commit_text(committer=bad_committer))
- self.assertCheckFails(Commit(),
+ self.assertCheckFails(Commit,
self.make_commit_text(author=bad_committer))
- self.assertCheckFails(Commit(), self.make_commit_text(author=None))
- self.assertCheckFails(Commit(), self.make_commit_text(committer=None))
- self.assertCheckFails(Commit(), self.make_commit_text(
+ self.assertCheckFails(Commit, self.make_commit_text(author=None))
+ self.assertCheckFails(Commit, self.make_commit_text(committer=None))
+ self.assertCheckFails(Commit, self.make_commit_text(
author=None, committer=None))
def test_check_duplicates(self):
@@ -369,9 +371,9 @@ text = '\n'.join(lines)
if lines[i].startswith('parent'):
# duplicate parents are ok for now
- self.assertCheckSucceeds(Commit(), text)
+ self.assertCheckSucceeds(Commit, text)
else:
- self.assertCheckFails(Commit(), text)
+ self.assertCheckFails(Commit, text)
def test_check_order(self):
lines = self.make_commit_lines(parents=[a_sha], encoding='UTF-8')
@@ -382,9 +384,9 @@ perm = list(perm)
text = '\n'.join(perm + rest)
if perm == headers:
- self.assertCheckSucceeds(Commit(), text)
+ self.assertCheckSucceeds(Commit, text)
else:
- self.assertCheckFails(Commit(), text)
+ self.assertCheckFails(Commit, text)
class TreeTests(ShaFileCheckTests):
@@ -406,6 +408,7 @@ def _do_test_parse_tree(self, parse_tree):
o = Tree.from_file(os.path.join(os.path.dirname(__file__), 'data',
'trees', tree_sha))
+ o._parse_file()
self.assertEquals([('a', 0100644, a_sha), ('b', 0100644, b_sha)],
list(parse_tree(o.as_raw_string())))
@@ -418,7 +421,7 @@ self._do_test_parse_tree(parse_tree)
def test_check(self):
- t = Tree()
+ t = Tree
sha = hex_to_sha(a_sha)
# filenames
@@ -530,26 +533,26 @@ self.assertEquals("v2.6.22-rc7", x.name)
def test_check(self):
- self.assertCheckSucceeds(Tag(), self.make_tag_text())
- self.assertCheckFails(Tag(), self.make_tag_text(object_sha=None))
- self.assertCheckFails(Tag(), self.make_tag_text(object_type_name=None))
- self.assertCheckFails(Tag(), self.make_tag_text(name=None))
- self.assertCheckFails(Tag(), self.make_tag_text(name=''))
- self.assertCheckFails(Tag(), self.make_tag_text(
+ self.assertCheckSucceeds(Tag, self.make_tag_text())
+ self.assertCheckFails(Tag, self.make_tag_text(object_sha=None))
+ self.assertCheckFails(Tag, self.make_tag_text(object_type_name=None))
+ self.assertCheckFails(Tag, self.make_tag_text(name=None))
+ self.assertCheckFails(Tag, self.make_tag_text(name=''))
+ self.assertCheckFails(Tag, self.make_tag_text(
object_type_name="foobar"))
- self.assertCheckFails(Tag(), self.make_tag_text(
+ self.assertCheckFails(Tag, self.make_tag_text(
tagger="some guy without an email address 1183319674 -0700"))
- self.assertCheckFails(Tag(), self.make_tag_text(
+ self.assertCheckFails(Tag, self.make_tag_text(
tagger=("Linus Torvalds <torvalds@woody.linux-foundation.org> "
"Sun 7 Jul 2007 12:54:34 +0700")))
- self.assertCheckFails(Tag(), self.make_tag_text(object_sha="xxx"))
+ self.assertCheckFails(Tag, self.make_tag_text(object_sha="xxx"))
def test_check_duplicates(self):
# duplicate each of the header fields
for i in xrange(4):
lines = self.make_tag_lines()
lines.insert(i, lines[i])
- self.assertCheckFails(Tag(), '\n'.join(lines))
+ self.assertCheckFails(Tag, '\n'.join(lines))
def test_check_order(self):
lines = self.make_tag_lines()
@@ -560,9 +563,9 @@ perm = list(perm)
text = '\n'.join(perm + rest)
if perm == headers:
- self.assertCheckSucceeds(Tag(), text)
+ self.assertCheckSucceeds(Tag, text)
else:
- self.assertCheckFails(Tag(), text)
+ self.assertCheckFails(Tag, text)
class CheckTests(unittest.TestCase):
|
Loading...