--- src/org/python/modules/cPickle.java (revision 3410) +++ src/org/python/modules/cPickle.java Sun Aug 12 12:54:58 MSD 2007 @@ -13,6 +13,7 @@ package org.python.modules; import java.util.*; +import java.math.BigInteger; import org.python.core.*; import org.python.core.imp; @@ -335,13 +336,13 @@ /** * File format version we write. */ - public static final String format_version = "1.3"; + public static final String format_version = "2.0"; /** * Old format versions we can read. */ public static final String[] compatible_formats = - new String[] { "1.0", "1.1", "1.2" }; + new String[] { "1.0", "1.1", "1.2", "1.3", "2.0" }; /** * Highest protocol version supported. @@ -403,7 +404,22 @@ final static char SETITEMS = 'u'; final static char BINFLOAT = 'G'; + final static char PROTO = 0x80; + final static char NEWOBJ = 0x81; + final static char EXT1 = 0x82; + final static char EXT2 = 0x83; + final static char EXT4 = 0x84; + final static char TUPLE1 = 0x85; + final static char TUPLE2 = 0x86; + final static char TUPLE3 = 0x87; + final static char NEWTRUE = 0x88; + final static char NEWFALSE = 0x89; + final static char LONG1 = 0x8A; + final static char LONG4 = 0x8B; + private static PyDictionary dispatch_table = null; + private static PyDictionary extension_registry = null; + private static PyDictionary inverted_registry = null; private static PyType BuiltinFunctionType = @@ -434,6 +450,8 @@ PyType.fromClass(PyNone.class); private static PyType StringType = PyType.fromClass(PyString.class); + private static PyType UnicodeType = + PyType.fromClass(PyUnicode.class); private static PyType TupleType = PyType.fromClass(PyTuple.class); private static PyType FileType = @@ -444,6 +462,8 @@ private static PyObject dict; + private static final int BATCHSIZE = 1024; + /** * Initialization when module is imported. */ @@ -457,6 +477,8 @@ PyModule copyreg = (PyModule)importModule("copy_reg"); dispatch_table = (PyDictionary)copyreg.__getattr__("dispatch_table"); + extension_registry = (PyDictionary)copyreg.__getattr__("_extension_registry"); + inverted_registry = (PyDictionary)copyreg.__getattr__("_inverted_registry"); PickleError = buildClass("PickleError", Py.Exception, "_PickleError", ""); @@ -545,7 +567,7 @@ * @returns a new Pickler instance. */ public static Pickler Pickler(PyObject file) { - return new Pickler(file, false); + return new Pickler(file, 0); } @@ -554,11 +576,11 @@ * @param file a file-like object, can be a cStringIO.StringIO, * a PyFile or any python object which implements a * write method. - * @param bin when true, the output will be written as binary data. + * @param protocol pickle protocol version (0 - text, 1 - pre-2.3 binary, 2 - 2.3) * @returns a new Pickler instance. */ - public static Pickler Pickler(PyObject file, boolean bin) { - return new Pickler(file, bin); + public static Pickler Pickler(PyObject file, int protocol) { + return new Pickler(file, protocol); } @@ -584,7 +606,7 @@ * @returns a new Unpickler instance. */ public static void dump(PyObject object, PyObject file) { - dump(object, file, false); + dump(object, file, 0); } /** @@ -593,11 +615,11 @@ * @param file a file-like object, can be a cStringIO.StringIO, * a PyFile or any python object which implements a * write method. - * @param bin when true, the output will be written as binary data. + * @param protocol pickle protocol version (0 - text, 1 - pre-2.3 binary, 2 - 2.3) * @returns a new Unpickler instance. */ - public static void dump(PyObject object, PyObject file, boolean bin) { - new Pickler(file, bin).dump(object); + public static void dump(PyObject object, PyObject file, int protocol) { + new Pickler(file, protocol).dump(object); } @@ -607,19 +629,19 @@ * @returns a string representing the pickled object. */ public static String dumps(PyObject object) { - return dumps(object, false); + return dumps(object, 0); } /** * Shorthand function which pickles and returns the string representation. * @param object a data object which should be pickled. - * @param bin when true, the output will be written as binary data. + * @param protocol pickle protocol version (0 - text, 1 - pre-2.3 binary, 2 - 2.3) * @returns a string representing the pickled object. */ - public static String dumps(PyObject object, boolean bin) { + public static String dumps(PyObject object, int protocol) { cStringIO.StringIO file = cStringIO.StringIO(); - dump(object, file, bin); + dump(object, file, protocol); return file.getvalue(); } @@ -784,11 +806,11 @@ /** * The Pickler object * @see cPickle#Pickler(PyObject) - * @see cPickle#Pickler(PyObject,boolean) + * @see cPickle#Pickler(PyObject,int) */ static public class Pickler { private IOFile file; - private boolean bin; + private int protocol; /** * The undocumented attribute fast of the C version of cPickle disables @@ -822,9 +844,9 @@ public PyObject inst_persistent_id = null; - public Pickler(PyObject file, boolean bin) { + public Pickler(PyObject file, int protocol) { this.file = createIOFile(file); - this.bin = bin; + this.protocol = protocol; } @@ -833,6 +855,10 @@ * @param object The object which will be pickled. */ public void dump(PyObject object) { + if (protocol >= 2) { + file.write(PROTO); + file.write((char) protocol); + } save(object); file.write(STOP); file.flush(); @@ -846,7 +872,7 @@ // Save name as in pickle.py but semantics are slightly changed. private void put(int i) { - if (bin) { + if (protocol > 0) { if (i < 256) { file.write(BINPUT); file.write((char)i); @@ -867,7 +893,7 @@ // Same name as in pickle.py but semantics are slightly changed. private void get(int i) { - if (bin) { + if (protocol > 0) { if (i < 256) { file.write(BINGET); file.write((char)i); @@ -907,7 +933,7 @@ PyType t = object.getType(); if (t == TupleType && object.__len__() == 0) { - if (bin) + if (protocol > 0) save_empty_tuple(object); else save_tuple(object); @@ -934,10 +960,15 @@ PyObject tup = null; PyObject reduce = dispatch_table.__finditem__(t); if (reduce == null) { + reduce = object.__findattr__("__reduce_ex__"); + if (reduce != null) { + tup = reduce.__call__(Py.newInteger(protocol)); + } else { - reduce = object.__findattr__("__reduce__"); - if (reduce == null) - throw new PyException(UnpickleableError, object); - tup = reduce.__call__(); + reduce = object.__findattr__("__reduce__"); + if (reduce == null) + throw new PyException(UnpickleableError, object); + tup = reduce.__call__(); + } } else { tup = reduce.__call__(object); } @@ -954,15 +985,17 @@ } int l = tup.__len__(); - if (l != 2 && l != 3) { + if (l < 2 || l > 5) { throw new PyException(PicklingError, "tuple returned by " + reduce.__repr__() + - " must contain only two or three elements"); + " must contain two to five elements"); } PyObject callable = tup.__finditem__(0); PyObject arg_tup = tup.__finditem__(1); PyObject state = (l > 2) ? tup.__finditem__(2) : Py.None; + PyObject listitems = (l > 3) ? tup.__finditem__(3) : Py.None; + PyObject dictitems = (l > 4) ? tup.__finditem__(4) : Py.None; if (!(arg_tup instanceof PyTuple) && arg_tup != Py.None) { throw new PyException(PicklingError, @@ -970,14 +1003,14 @@ reduce.__repr__() + " must be a tuple"); } - save_reduce(callable, arg_tup, state); + save_reduce(callable, arg_tup, state, listitems, dictitems); put(putMemo(d, object)); } final private void save_pers(PyObject pid) { - if (!bin) { + if (protocol == 0) { file.write(PERSID); file.write(pid.toString()); file.write("\n"); @@ -988,11 +1021,29 @@ } final private void save_reduce(PyObject callable, PyObject arg_tup, - PyObject state) + PyObject state, PyObject listitems, PyObject dictitems) { + PyObject callableName = callable.__findattr__("__name__"); + if (protocol >= 2 && callableName != null && "__newobj__".equals(callableName.toString())) { + PyObject cls = arg_tup.__finditem__(0); + if (cls.__findattr__("__new__") == null) + throw new PyException(PicklingError, "args[0] from __newobj__ args has no __new__"); + // TODO: check class + save(cls); + save(arg_tup.__getslice__(Py.newInteger(1), Py.None)); + file.write(NEWOBJ); + } + else { - save(callable); - save(arg_tup); - file.write(REDUCE); + save(callable); + save(arg_tup); + file.write(REDUCE); + } + if (listitems != Py.None) { + batch_appends(listitems); + } + if (dictitems != Py.None) { + batch_setitems(dictitems); + } if (state != Py.None) { save(state); file.write(BUILD); @@ -1007,6 +1058,8 @@ save_none(object); else if (type == StringType) save_string(object); + else if (type == UnicodeType) + save_unicode(object); else if (type == IntType) save_int(object); else if (type == LongType) @@ -1043,7 +1096,7 @@ } final private void save_int(PyObject object) { - if (bin) { + if (protocol > 0) { int l = ((PyInteger)object).getValue(); char i1 = (char)( l & 0xFF); char i2 = (char)((l >>> 8 ) & 0xFF); @@ -1075,21 +1128,56 @@ private void save_bool(PyObject object) { int value = ((PyBoolean) object).getValue(); + if (protocol >= 2) { + file.write(value != 0 ? NEWTRUE : NEWFALSE); + } + else { - file.write(INT); - file.write(value != 0 ? "01" : "00"); - file.write("\n"); - } + file.write(INT); + file.write(value != 0 ? "01" : "00"); + file.write("\n"); + } + } - - final private void save_long(PyObject object) { + private void save_long(PyObject object) { + if (protocol >= 2) { + BigInteger integer = ((PyLong) object).getValue(); + byte[] bytes = integer.toByteArray(); + int l = bytes.length; + if (l < 256) { + file.write(LONG1); + file.write((char) l); + } + else { + file.write(LONG4); + writeInt4(l); + } + for(int i=0; i>> 8 ) & 0xFF); + char i3 = (char)((l >>> 16) & 0xFF); + char i4 = (char)((l >>> 24) & 0xFF); + file.write(i1); + file.write(i2); + file.write(i3); + file.write(i4); + } + final private void save_float(PyObject object) { - if (bin) { + if (protocol > 0) { file.write(BINFLOAT); double value= ((PyFloat) object).getValue(); // It seems that struct.pack('>d', ..) and doubleToLongBits @@ -1115,7 +1203,7 @@ boolean unicode = ((PyString) object).isunicode(); String str = object.toString(); - if (bin) { + if (protocol > 0) { if (unicode) str = codecs.PyUnicode_EncodeUTF8(str, "struct"); int l = str.length(); @@ -1147,21 +1235,52 @@ put(putMemo(get_id(object), object)); } + private void save_unicode(PyObject object) { + if (protocol > 0) { + String str = codecs.PyUnicode_EncodeUTF8(object.toString(), "struct"); + file.write(BINUNICODE); + writeInt4(str.length()); + file.write(str); + } else { + file.write(UNICODE); + file.write(codecs.PyUnicode_EncodeRawUnicodeEscape(object.toString(), + "strict", true)); + file.write("\n"); + } + put(putMemo(get_id(object), object)); + } - final private void save_tuple(PyObject object) { + private void save_tuple(PyObject object) { int d = get_id(object); - file.write(MARK); - int len = object.__len__(); + if (len > 0 && len <= 3 && protocol >= 2) { - for (int i = 0; i < len; i++) - save(object.__finditem__(i)); + for (int i = 0; i < len; i++) + save(object.__finditem__(i)); + int m = getMemoPosition(d, object); + if (m >= 0) { + for (int i = 0; i < len; i++) + file.write(POP); + get(m); + } + else { + char opcode = (char) (TUPLE1 + len - 1); + file.write(opcode); + put(putMemo(d, object)); + } + return; + } + file.write(MARK); + + for (int i = 0; i < len; i++) + save(object.__finditem__(i)); + if (len > 0) { int m = getMemoPosition(d, object); if (m >= 0) { - if (bin) { + if (protocol > 0) { file.write(POP_MARK); get(m); return; @@ -1181,8 +1300,8 @@ file.write(EMPTY_TUPLE); } - final private void save_list(PyObject object) { - if (bin) + private void save_list(PyObject object) { + if (protocol > 0) file.write(EMPTY_LIST); else { file.write(MARK); @@ -1191,24 +1310,37 @@ put(putMemo(get_id(object), object)); - int len = object.__len__(); - boolean using_appends = bin && len > 1; + batch_appends(object); + } - if (using_appends) - file.write(MARK); - - for (int i = 0; i < len; i++) { - save(object.__finditem__(i)); - if (!using_appends) + private void batch_appends(PyObject object) { + PyObject iter = object.__iter__(); + int countInBatch = 0; + PyObject nextObj; + while((nextObj = iter.__iternext__()) != null) { + if (protocol == 0) { + save(nextObj); file.write(APPEND); - } + } - if (using_appends) + else { + if (countInBatch == 0) { + file.write(MARK); + } + countInBatch++; + save(nextObj); + if (countInBatch == BATCHSIZE) { - file.write(APPENDS); + file.write(APPENDS); + countInBatch = 0; - } + } + } + } + if (countInBatch > 0) + file.write(APPENDS); + } - final private void save_dict(PyObject object) { - if (bin) + private void save_dict(PyObject object) { + if (protocol > 0) file.write(EMPTY_DICT); else { file.write(MARK); @@ -1217,10 +1349,14 @@ put(putMemo(get_id(object), object)); + batch_setitems(object); + } + + private void batch_setitems(PyObject object) { PyObject list = object.invoke("keys"); int len = list.__len__(); - boolean using_setitems = (bin && len > 1); + boolean using_setitems = (protocol > 0 && len > 1); if (using_setitems) file.write(MARK); @@ -1233,8 +1369,13 @@ if (!using_setitems) file.write(SETITEM); + else if (i > 0 && i % BATCHSIZE == 0) { + file.write(SETITEMS); + if (len % BATCHSIZE != 0) + file.write(MARK); - } + } - if (using_setitems) + } + if (using_setitems && len % BATCHSIZE != 0) file.write(SETITEMS); } @@ -1255,7 +1396,7 @@ } file.write(MARK); - if (bin) + if (protocol > 0) save(cls); if (args != null) { @@ -1265,7 +1406,7 @@ } int mid = putMemo(get_id(object), object); - if (bin) { + if (protocol > 0) { file.write(OBJ); put(mid); } else { @@ -1303,6 +1444,28 @@ if (module == null || module == Py.None) module = whichmodule(object, name); + if (protocol >= 2) { + PyTuple extKey = new PyTuple(new PyObject[]{module, name}); + PyObject extCode = extension_registry.get(extKey); + if (extCode != Py.None) { + int code = ((PyInteger) extCode).getValue(); + if (code <= 0xFF) { + file.write(EXT1); + file.write((char) code); + } + else if (code <= 0xFFFF) { + file.write(EXT2); + file.write((char) (code & 0xFF)); + file.write((char) (code >> 8)); + } + else { + file.write(EXT4); + writeInt4(code); + } + return; + } + } + file.write(GLOBAL); file.write(module.toString()); file.write("\n"); @@ -1620,6 +1783,18 @@ case SETITEMS: load_setitems(); break; case BUILD: load_build(); break; case MARK: load_mark(); break; + case PROTO: load_proto(); break; + case NEWOBJ: load_newobj(); break; + case EXT1: load_ext(1); break; + case EXT2: load_ext(2); break; + case EXT4: load_ext(4); break; + case TUPLE1: load_small_tuple(1); break; + case TUPLE2: load_small_tuple(2); break; + case TUPLE3: load_small_tuple(3); break; + case NEWTRUE: load_boolean(true); break; + case NEWFALSE: load_boolean(false); break; + case LONG1: load_bin_long(1); break; + case LONG4: load_bin_long(4); break; case STOP: return load_stop(); } @@ -1640,7 +1815,13 @@ throw new PyException(Py.EOFError); } + private void load_proto() { + int proto = file.read(1).charAt(0); + if (proto < 0 || proto > 2) + throw Py.ValueError("unsupported pickle protocol: " + proto); + } + final private void load_persid() { String pid = file.readlineNoNl(); push(persistent_load.__call__(new PyString(pid))); @@ -1663,10 +1844,10 @@ // The following could be abstracted into a common string // -> int/long method. if (line.equals("01")) { - value = Py.newBoolean(true); + value = Py.True; } else if (line.equals("00")) { - value = Py.newBoolean(false); + value = Py.False; } else { try { @@ -1682,14 +1863,21 @@ push(value); } + private void load_boolean(boolean value) { + push(value ? Py.True : Py.False); + } final private void load_binint() { + int x = read_binint(); + push(new PyInteger(x)); + } + + private int read_binint() { String s = file.read(4); - int x = s.charAt(0) | + return s.charAt(0) | (s.charAt(1)<<8) | (s.charAt(2)<<16) | (s.charAt(3)<<24); - push(new PyInteger(x)); } @@ -1699,17 +1887,47 @@ } final private void load_binint2() { - String s = file.read(2); - int val = (s.charAt(1)) << 8 | (s.charAt(0)); + int val = read_binint2(); push(new PyInteger(val)); } + private int read_binint2() { + String s = file.read(2); + return (s.charAt(1)) << 8 | (s.charAt(0)); + } + final private void load_long() { String line = file.readlineNoNl(); push(new PyLong(line.substring(0, line.length()-1))); } + private void load_bin_long(int length) { + int longLength = read_binint(length); + String s = file.read(longLength); + byte[] bytes = new byte[s.length()]; + for(int i=0; i= 128) { + bytes [i] = (byte) (c - 256); + } + else { + bytes [i] = (byte) c; + } + } + BigInteger bigint = new BigInteger(bytes); + push(new PyLong(bigint)); + } + + private int read_binint(int length) { + if (length == 1) + return file.read(1).charAt(0); + else if (length == 2) + return read_binint2(); + else + return read_binint(); + } + final private void load_float() { String line = file.readlineNoNl(); push(new PyFloat(Double.valueOf(line).doubleValue())); @@ -1764,11 +1982,7 @@ final private void load_binstring() { - String d = file.read(4); - int len = d.charAt(0) | - (d.charAt(1)<<8) | - (d.charAt(2)<<16) | - (d.charAt(3)<<24); + int len = read_binint(); push(new PyString(file.read(len))); } @@ -1788,11 +2002,7 @@ } final private void load_binunicode() { - String d = file.read(4); - int len = d.charAt(0) | - (d.charAt(1)<<8) | - (d.charAt(2)<<16) | - (d.charAt(3)<<24); + int len = read_binint(); String line = file.read(len); push(new PyString(codecs.PyUnicode_DecodeUTF8(line, "strict"))); } @@ -1808,6 +2018,14 @@ push(new PyTuple(Py.EmptyObjects)); } + private void load_small_tuple(int length) { + PyObject[] data = new PyObject[length]; + for(int i=length-1; i >= 0; i--) { + data [i] = pop(); + } + push(new PyTuple(data)); + } + final private void load_empty_list() { push(new PyList(Py.EmptyObjects)); } @@ -1905,7 +2123,19 @@ return global; } + private void load_ext(int length) { + int code = read_binint(length); + // TODO: support _extension_cache + PyObject key = inverted_registry.get(Py.newInteger(code)); + if (key == null) { + throw new PyException(Py.ValueError, "unregistered extension code " + code); + } + String module = key.__finditem__(0).toString(); + String name = key.__finditem__(1).toString(); + push(find_class(module, name)); + } + final private void load_reduce() { PyObject arg_tup = pop(); PyObject callable = pop(); @@ -1919,6 +2149,17 @@ push(value); } + private void load_newobj() { + PyObject arg_tup = pop(); + PyObject cls = pop(); + PyObject[] args = new PyObject[arg_tup.__len__() + 1]; + args [0] = cls; + for(int i=1; i