Index: src/org/python/core/__builtin__.java =================================================================== --- src/org/python/core/__builtin__.java (revision 3608) +++ src/org/python/core/__builtin__.java (working copy) @@ -470,24 +470,30 @@ } public static PyObject dir(PyObject o) { - PyList ret = (PyList) o.__dir__(); - ret.sort(); - return ret; + PyObject ret = o.__dir__(); + if (!(ret instanceof PyList)) { + throw Py.TypeError("Expected keys() to be a list"); + } + ((PyList) ret).sort(); + return ret; } public static PyObject dir() { - PyObject l = locals(); - PyList ret; + PyObject l = locals(); + PyObject ret; - if (l instanceof PyStringMap) { - ret = ((PyStringMap) l).keys(); - } else if (l instanceof PyDictionary) { - ret = ((PyDictionary) l).keys(); - } + if (l instanceof PyStringMap) { + ret = ((PyStringMap) l).keys(); + } else if (l instanceof PyDictionary) { + ret = ((PyDictionary) l).keys(); + } - ret = (PyList) l.invoke("keys"); - ret.sort(); - return ret; + ret = l.invoke("keys"); + if (!(ret instanceof PyList)) { + throw Py.TypeError("Expected keys() to be a list"); + } + ((PyList) ret).sort(); + return ret; } public static PyObject divmod(PyObject x, PyObject y) { @@ -498,22 +504,53 @@ return new PyEnumerate(seq); } + private static boolean checkMapping(PyObject o) { + if(o instanceof PyStringMap) { + return true; + } + if(hasattr(o,"__getitem__")) { + return true; + } + return false; + } + public static PyObject eval(PyObject o, PyObject globals, PyObject locals) { - PyCode code; - if (o instanceof PyCode) { - code = (PyCode) o; - } else { - if (o instanceof PyString) { - code = compile(o.toString(), "", "eval"); - } else { - throw Py.TypeError("eval: argument 1 must be string or code object"); - } - } - return Py.runCode(code, locals, globals); + PyCode code; + if (o instanceof PyCode) { + code = (PyCode) o; + } else { + if (o instanceof PyString) { + code = compile(o.toString(), "", "eval"); + } else { + throw Py.TypeError("eval: argument 1 must be string or code object"); + } + } + + if (globals == Py.None) { + globals = null; + } + + if (locals == Py.None) { + locals = null; + } + + if (locals != null && !checkMapping(locals)) { + throw Py.TypeError("locals must be a mapping"); + } + + if (globals != null && + !(globals instanceof PyDictionary || globals instanceof PyStringMap)) { + + throw Py.TypeError(checkMapping(globals) ? + "globals must be a real dict; try eval(expr, {}, mapping)" : + "globals must be a dict"); + } + + return Py.runCode(code, locals, globals); } public static PyObject eval(PyObject o, PyObject globals) { - return eval(o, globals, globals); + return eval(o, globals, globals); } public static PyObject eval(PyObject o) { @@ -545,6 +582,15 @@ throw Py.IOError(e); } } + + if (locals == Py.None) { + locals = null; + } + + if (locals != null && !checkMapping(locals)) { + throw Py.TypeError("locals must be a mapping"); + } + Py.runCode(code, locals, globals); } Index: Lib/test/test_builtin.py =================================================================== --- Lib/test/test_builtin.py (revision 3608) +++ Lib/test/test_builtin.py (working copy) @@ -4,12 +4,12 @@ from test.test_support import fcmp, have_unicode, TESTFN, unlink from sets import Set -import sys, warnings, cStringIO +import sys, warnings, cStringIO, random, UserDict warnings.filterwarnings("ignore", "hex../oct.. of negative int", FutureWarning, __name__) warnings.filterwarnings("ignore", "integer argument expected", DeprecationWarning, "unittest") - +numruns = 0 class Squares: def __init__(self, max): @@ -277,7 +277,89 @@ f.close() execfile(TESTFN) + def test_general_eval(self): + # Tests that general mappings can be used for the locals argument + + class M: + "Test mapping interface versus possible calls from eval()." + def __getitem__(self, key): + if key == 'a': + return 12 + raise KeyError + def keys(self): + return list('xyz') + + m = M() + g = globals() + self.assertEqual(eval('a', g, m), 12) + self.assertRaises(NameError, eval, 'b', g, m) + self.assertEqual(eval('dir()', g, m), list('xyz')) + self.assertEqual(eval('globals()', g, m), g) + self.assertEqual(eval('locals()', g, m), m) + self.assertRaises(TypeError, eval, 'a', m) + class A: + "Non-mapping" + pass + m = A() + self.assertRaises(TypeError, eval, 'a', g, m) + + # Verify that dict subclasses work as well + class D(dict): + def __getitem__(self, key): + if key == 'a': + return 12 + return dict.__getitem__(self, key) + def keys(self): + return list('xyz') + + d = D() + self.assertEqual(eval('a', g, d), 12) + self.assertRaises(NameError, eval, 'b', g, d) + self.assertEqual(eval('dir()', g, d), list('xyz')) + self.assertEqual(eval('globals()', g, d), g) + self.assertEqual(eval('locals()', g, d), d) + + # Verify locals stores (used by list comps) + eval('[locals() for i in (2,3)]', g, d) + eval('[locals() for i in (2,3)]', g, UserDict.UserDict()) + + class SpreadSheet: + "Sample application showing nested, calculated lookups." + _cells = {} + def __setitem__(self, key, formula): + self._cells[key] = formula + def __getitem__(self, key): + return eval(self._cells[key], globals(), self) + + ss = SpreadSheet() + ss['a1'] = '5' + ss['a2'] = 'a1*6' + ss['a3'] = 'a2*7' + self.assertEqual(ss['a3'], 210) + + # Verify that dir() catches a non-list returned by eval + # SF bug #1004669 + class C: + def __getitem__(self, item): + raise KeyError(item) + def keys(self): + return 'a' + self.assertRaises(TypeError, eval, 'dir()', globals(), C()) + + # Done outside of the method test_z to get the correct scope + z = 0 + f = open(TESTFN, 'w') + f.write('z = z+1\n') + f.write('z = z*2\n') + f.close() + execfile(TESTFN) + def test_execfile(self): + global numruns + if numruns: + return + numruns += 1 + globals = {'a': 1, 'b': 2} locals = {'b': 200, 'c': 300} @@ -288,8 +370,29 @@ locals['z'] = 0 execfile(TESTFN, globals, locals) self.assertEqual(locals['z'], 2) + + class M: + "Test mapping interface versus possible calls from execfile()." + def __init__(self): + self.z = 10 + def __getitem__(self, key): + if key == 'z': + return self.z + raise KeyError + def __setitem__(self, key, value): + if key == 'z': + self.z = value + return + raise KeyError + + locals = M() + locals['z'] = 0 + execfile(TESTFN, globals, locals) + self.assertEqual(locals['z'], 2) + + self.assertRaises(TypeError, execfile) + self.assertRaises(TypeError, execfile, TESTFN, {}, ()) unlink(TESTFN) - self.assertRaises(TypeError, execfile) import os self.assertRaises(IOError, execfile, os.curdir) self.assertRaises(IOError, execfile, "I_dont_exist")