--- src/org/python/modules/cPickle.java (revision 3410)
+++ src/org/python/modules/cPickle.java Sun Aug 12 12:54:58 MSD 2007
@@ -13,6 +13,7 @@
package org.python.modules;
import java.util.*;
+import java.math.BigInteger;
import org.python.core.*;
import org.python.core.imp;
@@ -335,13 +336,13 @@
/**
* File format version we write.
*/
- public static final String format_version = "1.3";
+ public static final String format_version = "2.0";
/**
* Old format versions we can read.
*/
public static final String[] compatible_formats =
- new String[] { "1.0", "1.1", "1.2" };
+ new String[] { "1.0", "1.1", "1.2", "1.3", "2.0" };
/**
* Highest protocol version supported.
@@ -403,7 +404,22 @@
final static char SETITEMS = 'u';
final static char BINFLOAT = 'G';
+ final static char PROTO = 0x80;
+ final static char NEWOBJ = 0x81;
+ final static char EXT1 = 0x82;
+ final static char EXT2 = 0x83;
+ final static char EXT4 = 0x84;
+ final static char TUPLE1 = 0x85;
+ final static char TUPLE2 = 0x86;
+ final static char TUPLE3 = 0x87;
+ final static char NEWTRUE = 0x88;
+ final static char NEWFALSE = 0x89;
+ final static char LONG1 = 0x8A;
+ final static char LONG4 = 0x8B;
+
private static PyDictionary dispatch_table = null;
+ private static PyDictionary extension_registry = null;
+ private static PyDictionary inverted_registry = null;
private static PyType BuiltinFunctionType =
@@ -434,6 +450,8 @@
PyType.fromClass(PyNone.class);
private static PyType StringType =
PyType.fromClass(PyString.class);
+ private static PyType UnicodeType =
+ PyType.fromClass(PyUnicode.class);
private static PyType TupleType =
PyType.fromClass(PyTuple.class);
private static PyType FileType =
@@ -444,6 +462,8 @@
private static PyObject dict;
+ private static final int BATCHSIZE = 1024;
+
/**
* Initialization when module is imported.
*/
@@ -457,6 +477,8 @@
PyModule copyreg = (PyModule)importModule("copy_reg");
dispatch_table = (PyDictionary)copyreg.__getattr__("dispatch_table");
+ extension_registry = (PyDictionary)copyreg.__getattr__("_extension_registry");
+ inverted_registry = (PyDictionary)copyreg.__getattr__("_inverted_registry");
PickleError = buildClass("PickleError", Py.Exception,
"_PickleError", "");
@@ -545,7 +567,7 @@
* @returns a new Pickler instance.
*/
public static Pickler Pickler(PyObject file) {
- return new Pickler(file, false);
+ return new Pickler(file, 0);
}
@@ -554,11 +576,11 @@
* @param file a file-like object, can be a cStringIO.StringIO,
* a PyFile or any python object which implements a
* write method.
- * @param bin when true, the output will be written as binary data.
+ * @param protocol pickle protocol version (0 - text, 1 - pre-2.3 binary, 2 - 2.3)
* @returns a new Pickler instance.
*/
- public static Pickler Pickler(PyObject file, boolean bin) {
- return new Pickler(file, bin);
+ public static Pickler Pickler(PyObject file, int protocol) {
+ return new Pickler(file, protocol);
}
@@ -584,7 +606,7 @@
* @returns a new Unpickler instance.
*/
public static void dump(PyObject object, PyObject file) {
- dump(object, file, false);
+ dump(object, file, 0);
}
/**
@@ -593,11 +615,11 @@
* @param file a file-like object, can be a cStringIO.StringIO,
* a PyFile or any python object which implements a
* write method.
- * @param bin when true, the output will be written as binary data.
+ * @param protocol pickle protocol version (0 - text, 1 - pre-2.3 binary, 2 - 2.3)
* @returns a new Unpickler instance.
*/
- public static void dump(PyObject object, PyObject file, boolean bin) {
- new Pickler(file, bin).dump(object);
+ public static void dump(PyObject object, PyObject file, int protocol) {
+ new Pickler(file, protocol).dump(object);
}
@@ -607,19 +629,19 @@
* @returns a string representing the pickled object.
*/
public static String dumps(PyObject object) {
- return dumps(object, false);
+ return dumps(object, 0);
}
/**
* Shorthand function which pickles and returns the string representation.
* @param object a data object which should be pickled.
- * @param bin when true, the output will be written as binary data.
+ * @param protocol pickle protocol version (0 - text, 1 - pre-2.3 binary, 2 - 2.3)
* @returns a string representing the pickled object.
*/
- public static String dumps(PyObject object, boolean bin) {
+ public static String dumps(PyObject object, int protocol) {
cStringIO.StringIO file = cStringIO.StringIO();
- dump(object, file, bin);
+ dump(object, file, protocol);
return file.getvalue();
}
@@ -784,11 +806,11 @@
/**
* The Pickler object
* @see cPickle#Pickler(PyObject)
- * @see cPickle#Pickler(PyObject,boolean)
+ * @see cPickle#Pickler(PyObject,int)
*/
static public class Pickler {
private IOFile file;
- private boolean bin;
+ private int protocol;
/**
* The undocumented attribute fast of the C version of cPickle disables
@@ -822,9 +844,9 @@
public PyObject inst_persistent_id = null;
- public Pickler(PyObject file, boolean bin) {
+ public Pickler(PyObject file, int protocol) {
this.file = createIOFile(file);
- this.bin = bin;
+ this.protocol = protocol;
}
@@ -833,6 +855,10 @@
* @param object The object which will be pickled.
*/
public void dump(PyObject object) {
+ if (protocol >= 2) {
+ file.write(PROTO);
+ file.write((char) protocol);
+ }
save(object);
file.write(STOP);
file.flush();
@@ -846,7 +872,7 @@
// Save name as in pickle.py but semantics are slightly changed.
private void put(int i) {
- if (bin) {
+ if (protocol > 0) {
if (i < 256) {
file.write(BINPUT);
file.write((char)i);
@@ -867,7 +893,7 @@
// Same name as in pickle.py but semantics are slightly changed.
private void get(int i) {
- if (bin) {
+ if (protocol > 0) {
if (i < 256) {
file.write(BINGET);
file.write((char)i);
@@ -907,7 +933,7 @@
PyType t = object.getType();
if (t == TupleType && object.__len__() == 0) {
- if (bin)
+ if (protocol > 0)
save_empty_tuple(object);
else
save_tuple(object);
@@ -934,10 +960,15 @@
PyObject tup = null;
PyObject reduce = dispatch_table.__finditem__(t);
if (reduce == null) {
+ reduce = object.__findattr__("__reduce_ex__");
+ if (reduce != null) {
+ tup = reduce.__call__(Py.newInteger(protocol));
+ } else {
- reduce = object.__findattr__("__reduce__");
- if (reduce == null)
- throw new PyException(UnpickleableError, object);
- tup = reduce.__call__();
+ reduce = object.__findattr__("__reduce__");
+ if (reduce == null)
+ throw new PyException(UnpickleableError, object);
+ tup = reduce.__call__();
+ }
} else {
tup = reduce.__call__(object);
}
@@ -954,15 +985,17 @@
}
int l = tup.__len__();
- if (l != 2 && l != 3) {
+ if (l < 2 || l > 5) {
throw new PyException(PicklingError,
"tuple returned by " + reduce.__repr__() +
- " must contain only two or three elements");
+ " must contain two to five elements");
}
PyObject callable = tup.__finditem__(0);
PyObject arg_tup = tup.__finditem__(1);
PyObject state = (l > 2) ? tup.__finditem__(2) : Py.None;
+ PyObject listitems = (l > 3) ? tup.__finditem__(3) : Py.None;
+ PyObject dictitems = (l > 4) ? tup.__finditem__(4) : Py.None;
if (!(arg_tup instanceof PyTuple) && arg_tup != Py.None) {
throw new PyException(PicklingError,
@@ -970,14 +1003,14 @@
reduce.__repr__() + " must be a tuple");
}
- save_reduce(callable, arg_tup, state);
+ save_reduce(callable, arg_tup, state, listitems, dictitems);
put(putMemo(d, object));
}
final private void save_pers(PyObject pid) {
- if (!bin) {
+ if (protocol == 0) {
file.write(PERSID);
file.write(pid.toString());
file.write("\n");
@@ -988,11 +1021,29 @@
}
final private void save_reduce(PyObject callable, PyObject arg_tup,
- PyObject state)
+ PyObject state, PyObject listitems, PyObject dictitems)
{
+ PyObject callableName = callable.__findattr__("__name__");
+ if (protocol >= 2 && callableName != null && "__newobj__".equals(callableName.toString())) {
+ PyObject cls = arg_tup.__finditem__(0);
+ if (cls.__findattr__("__new__") == null)
+ throw new PyException(PicklingError, "args[0] from __newobj__ args has no __new__");
+ // TODO: check class
+ save(cls);
+ save(arg_tup.__getslice__(Py.newInteger(1), Py.None));
+ file.write(NEWOBJ);
+ }
+ else {
- save(callable);
- save(arg_tup);
- file.write(REDUCE);
+ save(callable);
+ save(arg_tup);
+ file.write(REDUCE);
+ }
+ if (listitems != Py.None) {
+ batch_appends(listitems);
+ }
+ if (dictitems != Py.None) {
+ batch_setitems(dictitems);
+ }
if (state != Py.None) {
save(state);
file.write(BUILD);
@@ -1007,6 +1058,8 @@
save_none(object);
else if (type == StringType)
save_string(object);
+ else if (type == UnicodeType)
+ save_unicode(object);
else if (type == IntType)
save_int(object);
else if (type == LongType)
@@ -1043,7 +1096,7 @@
}
final private void save_int(PyObject object) {
- if (bin) {
+ if (protocol > 0) {
int l = ((PyInteger)object).getValue();
char i1 = (char)( l & 0xFF);
char i2 = (char)((l >>> 8 ) & 0xFF);
@@ -1075,21 +1128,56 @@
private void save_bool(PyObject object) {
int value = ((PyBoolean) object).getValue();
+ if (protocol >= 2) {
+ file.write(value != 0 ? NEWTRUE : NEWFALSE);
+ }
+ else {
- file.write(INT);
- file.write(value != 0 ? "01" : "00");
- file.write("\n");
- }
+ file.write(INT);
+ file.write(value != 0 ? "01" : "00");
+ file.write("\n");
+ }
+ }
-
- final private void save_long(PyObject object) {
+ private void save_long(PyObject object) {
+ if (protocol >= 2) {
+ BigInteger integer = ((PyLong) object).getValue();
+ byte[] bytes = integer.toByteArray();
+ int l = bytes.length;
+ if (l < 256) {
+ file.write(LONG1);
+ file.write((char) l);
+ }
+ else {
+ file.write(LONG4);
+ writeInt4(l);
+ }
+ for(int i=0; i>> 8 ) & 0xFF);
+ char i3 = (char)((l >>> 16) & 0xFF);
+ char i4 = (char)((l >>> 24) & 0xFF);
+ file.write(i1);
+ file.write(i2);
+ file.write(i3);
+ file.write(i4);
+ }
+
final private void save_float(PyObject object) {
- if (bin) {
+ if (protocol > 0) {
file.write(BINFLOAT);
double value= ((PyFloat) object).getValue();
// It seems that struct.pack('>d', ..) and doubleToLongBits
@@ -1115,7 +1203,7 @@
boolean unicode = ((PyString) object).isunicode();
String str = object.toString();
- if (bin) {
+ if (protocol > 0) {
if (unicode)
str = codecs.PyUnicode_EncodeUTF8(str, "struct");
int l = str.length();
@@ -1147,21 +1235,52 @@
put(putMemo(get_id(object), object));
}
+ private void save_unicode(PyObject object) {
+ if (protocol > 0) {
+ String str = codecs.PyUnicode_EncodeUTF8(object.toString(), "struct");
+ file.write(BINUNICODE);
+ writeInt4(str.length());
+ file.write(str);
+ } else {
+ file.write(UNICODE);
+ file.write(codecs.PyUnicode_EncodeRawUnicodeEscape(object.toString(),
+ "strict", true));
+ file.write("\n");
+ }
+ put(putMemo(get_id(object), object));
+ }
- final private void save_tuple(PyObject object) {
+ private void save_tuple(PyObject object) {
int d = get_id(object);
- file.write(MARK);
-
int len = object.__len__();
+ if (len > 0 && len <= 3 && protocol >= 2) {
- for (int i = 0; i < len; i++)
- save(object.__finditem__(i));
+ for (int i = 0; i < len; i++)
+ save(object.__finditem__(i));
+ int m = getMemoPosition(d, object);
+ if (m >= 0) {
+ for (int i = 0; i < len; i++)
+ file.write(POP);
+ get(m);
+ }
+ else {
+ char opcode = (char) (TUPLE1 + len - 1);
+ file.write(opcode);
+ put(putMemo(d, object));
+ }
+ return;
+ }
+ file.write(MARK);
+
+ for (int i = 0; i < len; i++)
+ save(object.__finditem__(i));
+
if (len > 0) {
int m = getMemoPosition(d, object);
if (m >= 0) {
- if (bin) {
+ if (protocol > 0) {
file.write(POP_MARK);
get(m);
return;
@@ -1181,8 +1300,8 @@
file.write(EMPTY_TUPLE);
}
- final private void save_list(PyObject object) {
- if (bin)
+ private void save_list(PyObject object) {
+ if (protocol > 0)
file.write(EMPTY_LIST);
else {
file.write(MARK);
@@ -1191,24 +1310,37 @@
put(putMemo(get_id(object), object));
- int len = object.__len__();
- boolean using_appends = bin && len > 1;
+ batch_appends(object);
+ }
- if (using_appends)
- file.write(MARK);
-
- for (int i = 0; i < len; i++) {
- save(object.__finditem__(i));
- if (!using_appends)
+ private void batch_appends(PyObject object) {
+ PyObject iter = object.__iter__();
+ int countInBatch = 0;
+ PyObject nextObj;
+ while((nextObj = iter.__iternext__()) != null) {
+ if (protocol == 0) {
+ save(nextObj);
file.write(APPEND);
- }
+ }
- if (using_appends)
+ else {
+ if (countInBatch == 0) {
+ file.write(MARK);
+ }
+ countInBatch++;
+ save(nextObj);
+ if (countInBatch == BATCHSIZE) {
- file.write(APPENDS);
+ file.write(APPENDS);
+ countInBatch = 0;
- }
+ }
+ }
+ }
+ if (countInBatch > 0)
+ file.write(APPENDS);
+ }
- final private void save_dict(PyObject object) {
- if (bin)
+ private void save_dict(PyObject object) {
+ if (protocol > 0)
file.write(EMPTY_DICT);
else {
file.write(MARK);
@@ -1217,10 +1349,14 @@
put(putMemo(get_id(object), object));
+ batch_setitems(object);
+ }
+
+ private void batch_setitems(PyObject object) {
PyObject list = object.invoke("keys");
int len = list.__len__();
- boolean using_setitems = (bin && len > 1);
+ boolean using_setitems = (protocol > 0 && len > 1);
if (using_setitems)
file.write(MARK);
@@ -1233,8 +1369,13 @@
if (!using_setitems)
file.write(SETITEM);
+ else if (i > 0 && i % BATCHSIZE == 0) {
+ file.write(SETITEMS);
+ if (len % BATCHSIZE != 0)
+ file.write(MARK);
- }
+ }
- if (using_setitems)
+ }
+ if (using_setitems && len % BATCHSIZE != 0)
file.write(SETITEMS);
}
@@ -1255,7 +1396,7 @@
}
file.write(MARK);
- if (bin)
+ if (protocol > 0)
save(cls);
if (args != null) {
@@ -1265,7 +1406,7 @@
}
int mid = putMemo(get_id(object), object);
- if (bin) {
+ if (protocol > 0) {
file.write(OBJ);
put(mid);
} else {
@@ -1303,6 +1444,28 @@
if (module == null || module == Py.None)
module = whichmodule(object, name);
+ if (protocol >= 2) {
+ PyTuple extKey = new PyTuple(new PyObject[]{module, name});
+ PyObject extCode = extension_registry.get(extKey);
+ if (extCode != Py.None) {
+ int code = ((PyInteger) extCode).getValue();
+ if (code <= 0xFF) {
+ file.write(EXT1);
+ file.write((char) code);
+ }
+ else if (code <= 0xFFFF) {
+ file.write(EXT2);
+ file.write((char) (code & 0xFF));
+ file.write((char) (code >> 8));
+ }
+ else {
+ file.write(EXT4);
+ writeInt4(code);
+ }
+ return;
+ }
+ }
+
file.write(GLOBAL);
file.write(module.toString());
file.write("\n");
@@ -1620,6 +1783,18 @@
case SETITEMS: load_setitems(); break;
case BUILD: load_build(); break;
case MARK: load_mark(); break;
+ case PROTO: load_proto(); break;
+ case NEWOBJ: load_newobj(); break;
+ case EXT1: load_ext(1); break;
+ case EXT2: load_ext(2); break;
+ case EXT4: load_ext(4); break;
+ case TUPLE1: load_small_tuple(1); break;
+ case TUPLE2: load_small_tuple(2); break;
+ case TUPLE3: load_small_tuple(3); break;
+ case NEWTRUE: load_boolean(true); break;
+ case NEWFALSE: load_boolean(false); break;
+ case LONG1: load_bin_long(1); break;
+ case LONG4: load_bin_long(4); break;
case STOP:
return load_stop();
}
@@ -1640,7 +1815,13 @@
throw new PyException(Py.EOFError);
}
+ private void load_proto() {
+ int proto = file.read(1).charAt(0);
+ if (proto < 0 || proto > 2)
+ throw Py.ValueError("unsupported pickle protocol: " + proto);
+ }
+
final private void load_persid() {
String pid = file.readlineNoNl();
push(persistent_load.__call__(new PyString(pid)));
@@ -1663,10 +1844,10 @@
// The following could be abstracted into a common string
// -> int/long method.
if (line.equals("01")) {
- value = Py.newBoolean(true);
+ value = Py.True;
}
else if (line.equals("00")) {
- value = Py.newBoolean(false);
+ value = Py.False;
}
else {
try {
@@ -1682,14 +1863,21 @@
push(value);
}
+ private void load_boolean(boolean value) {
+ push(value ? Py.True : Py.False);
+ }
final private void load_binint() {
+ int x = read_binint();
+ push(new PyInteger(x));
+ }
+
+ private int read_binint() {
String s = file.read(4);
- int x = s.charAt(0) |
+ return s.charAt(0) |
(s.charAt(1)<<8) |
(s.charAt(2)<<16) |
(s.charAt(3)<<24);
- push(new PyInteger(x));
}
@@ -1699,17 +1887,47 @@
}
final private void load_binint2() {
- String s = file.read(2);
- int val = (s.charAt(1)) << 8 | (s.charAt(0));
+ int val = read_binint2();
push(new PyInteger(val));
}
+ private int read_binint2() {
+ String s = file.read(2);
+ return (s.charAt(1)) << 8 | (s.charAt(0));
+ }
+
final private void load_long() {
String line = file.readlineNoNl();
push(new PyLong(line.substring(0, line.length()-1)));
}
+ private void load_bin_long(int length) {
+ int longLength = read_binint(length);
+ String s = file.read(longLength);
+ byte[] bytes = new byte[s.length()];
+ for(int i=0; i= 128) {
+ bytes [i] = (byte) (c - 256);
+ }
+ else {
+ bytes [i] = (byte) c;
+ }
+ }
+ BigInteger bigint = new BigInteger(bytes);
+ push(new PyLong(bigint));
+ }
+
+ private int read_binint(int length) {
+ if (length == 1)
+ return file.read(1).charAt(0);
+ else if (length == 2)
+ return read_binint2();
+ else
+ return read_binint();
+ }
+
final private void load_float() {
String line = file.readlineNoNl();
push(new PyFloat(Double.valueOf(line).doubleValue()));
@@ -1764,11 +1982,7 @@
final private void load_binstring() {
- String d = file.read(4);
- int len = d.charAt(0) |
- (d.charAt(1)<<8) |
- (d.charAt(2)<<16) |
- (d.charAt(3)<<24);
+ int len = read_binint();
push(new PyString(file.read(len)));
}
@@ -1788,11 +2002,7 @@
}
final private void load_binunicode() {
- String d = file.read(4);
- int len = d.charAt(0) |
- (d.charAt(1)<<8) |
- (d.charAt(2)<<16) |
- (d.charAt(3)<<24);
+ int len = read_binint();
String line = file.read(len);
push(new PyString(codecs.PyUnicode_DecodeUTF8(line, "strict")));
}
@@ -1808,6 +2018,14 @@
push(new PyTuple(Py.EmptyObjects));
}
+ private void load_small_tuple(int length) {
+ PyObject[] data = new PyObject[length];
+ for(int i=length-1; i >= 0; i--) {
+ data [i] = pop();
+ }
+ push(new PyTuple(data));
+ }
+
final private void load_empty_list() {
push(new PyList(Py.EmptyObjects));
}
@@ -1905,7 +2123,19 @@
return global;
}
+ private void load_ext(int length) {
+ int code = read_binint(length);
+ // TODO: support _extension_cache
+ PyObject key = inverted_registry.get(Py.newInteger(code));
+ if (key == null) {
+ throw new PyException(Py.ValueError, "unregistered extension code " + code);
+ }
+ String module = key.__finditem__(0).toString();
+ String name = key.__finditem__(1).toString();
+ push(find_class(module, name));
+ }
+
final private void load_reduce() {
PyObject arg_tup = pop();
PyObject callable = pop();
@@ -1919,6 +2149,17 @@
push(value);
}
+ private void load_newobj() {
+ PyObject arg_tup = pop();
+ PyObject cls = pop();
+ PyObject[] args = new PyObject[arg_tup.__len__() + 1];
+ args [0] = cls;
+ for(int i=1; i