asdl_c.py 39.3 KB
Newer Older
Jeremy Hylton's avatar
Jeremy Hylton committed
1 2 3 4 5 6
#! /usr/bin/env python
"""Generate C code from an ASDL description."""

# TO DO
# handle fields that have a type but no name

7
import os, sys
Jeremy Hylton's avatar
Jeremy Hylton committed
8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49

import asdl

TABSIZE = 8
MAX_COL = 80

def get_c_type(name):
    """Map an ASDL type name to the C type name used for it.

    Names listed in asdl.builtin_types (identifier, string, int, bool) are
    returned unchanged; every other name N becomes "N_ty".  name may be a
    plain string or an asdl.Id wrapping one.
    """
    # XXX ack!  need to figure out where Id is useful and where string
    if isinstance(name, asdl.Id):
        name = name.value
    return name if name in asdl.builtin_types else "%s_ty" % name

def reflow_lines(s, depth):
    """Reflow the line s indented depth tabs.

    Return a sequence of lines where no line extends beyond MAX_COL
    when properly indented.  The first line is properly indented based
    exclusively on depth * TABSIZE.  All following lines -- these are
    the reflowed lines generated by this function -- start at the same
    column as the first character beyond the opening { in the first
    line (or beyond the opening ( when the first line has no brace).

    Lines are only broken at spaces.  Raises ValueError when a stretch of
    text wider than the available width contains no space to break at
    (except for the 'GeneratorExp' special case below).
    """
    size = MAX_COL - depth * TABSIZE
    if len(s) < size:
        return [s]

    lines = []
    cur = s
    padding = ""
    while len(cur) > size:
        i = cur.rfind(' ', 0, size)
        # XXX this should be fixed for real
        if i == -1 and 'GeneratorExp' in cur:
            i = size + 3
        if i == -1:
            # Explicit raise instead of assert: under "python -O" an assert
            # is stripped, and with i == -1 the slice cur[i+1:] below leaves
            # cur unchanged, so the loop would never terminate.
            raise ValueError("Impossible line %d to reflow: %r" % (size, s))
        lines.append(padding + cur[:i])
        if len(lines) == 1:
            # Continuation lines line up just past the first opening brace
            # (and the space after it) or, failing that, just past the first
            # opening paren; the usable width shrinks by the same amount.
            j = cur.find('{', 0, i)
            if j >= 0:
                j += 2 # account for the brace and the space after it
                size -= j
                padding = " " * j
            else:
                j = cur.find('(', 0, i)
                if j >= 0:
                    j += 1 # account for the paren (no space after it)
                    size -= j
                    padding = " " * j
        cur = cur[i+1:]
    lines.append(padding + cur)
    return lines

def is_simple(sum):
    """Return True when none of the sum's constructors carries fields.

    Example of a simple sum:
    unaryop = Invert | Not | UAdd | USub
    """
    return not any(t.fields for t in sum.types)

81

Jeremy Hylton's avatar
Jeremy Hylton committed
82 83 84 85 86 87 88
class EmitVisitor(asdl.VisitorBase):
    """Visitor base class that writes generated source lines to a file."""

    def __init__(self, file):
        # file: any object with a write() method; all output goes there.
        self.file = file
        super(EmitVisitor, self).__init__()

    def emit(self, s, depth, reflow=True):
        """Write s indented by depth levels of TABSIZE spaces.

        When reflow is true, s is first split with reflow_lines() so that
        no line extends beyond MAX_COL; otherwise s is written as one line.
        Every written line is terminated with a newline.
        """
        if reflow:
            lines = reflow_lines(s, depth)
        else:
            lines = [s]
        indent = " " * TABSIZE * depth
        for line in lines:
            self.file.write(indent + line + "\n")

99

Jeremy Hylton's avatar
Jeremy Hylton committed
100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
class TypeDefVisitor(EmitVisitor):
    """Emit a C typedef for every type defined in the ASDL module.

    Simple sums become enums; sums with constructors and products become
    typedefs of a pointer to "struct _<name>".
    """

    def visitModule(self, mod):
        for dfn in mod.dfns:
            self.visit(dfn)

    def visitType(self, type, depth=0):
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        if is_simple(sum):
            self.simple_sum(sum, name, depth)
        else:
            self.sum_with_constructors(sum, name, depth)

    def simple_sum(self, sum, name, depth):
        # Enumerator values start at 1.
        enums = ", ".join(["%s=%d" % (t.name, i + 1)
                           for i, t in enumerate(sum.types)])
        ctype = get_c_type(name)
        s = "typedef enum _%s { %s } %s;" % (name, enums, ctype)
        self.emit(s, depth)
        self.emit("", depth)

    def _emit_struct_pointer_typedef(self, name, depth):
        # Shared by sums with constructors and by products: both are
        # represented in C as a pointer to "struct _<name>".
        ctype = get_c_type(name)
        self.emit("typedef struct _%s *%s;" % (name, ctype), depth)
        self.emit("", depth)

    def sum_with_constructors(self, sum, name, depth):
        self._emit_struct_pointer_typedef(name, depth)

    def visitProduct(self, product, name, depth):
        self._emit_struct_pointer_typedef(name, depth)

137

Jeremy Hylton's avatar
Jeremy Hylton committed
138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159
class StructVisitor(EmitVisitor):
    """Visitor to generate the C struct definitions for the AST."""

    def visitModule(self, mod):
        for dfn in mod.dfns:
            self.visit(dfn)

    def visitType(self, type, depth=0):
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        # Only sums with constructors need a struct.
        if not is_simple(sum):
            self.sum_with_constructors(sum, name, depth)

    def sum_with_constructors(self, sum, name, depth):
        """Emit the kind enum and the tagged-union struct for a sum type.

        The struct holds a kind tag, a union v with one member per
        constructor that has fields, and one member per sum attribute.
        """
        # Kind values start at 1.
        kinds = ", ".join(["%s_kind=%d" % (t.name, i + 1)
                           for i, t in enumerate(sum.types)])
        self.emit("enum _%s_kind {%s};" % (name, kinds), depth)
        self.emit("struct _%s {" % name, depth)
        self.emit("enum _%s_kind kind;" % name, depth + 1)
        self.emit("union {", depth + 1)
        for t in sum.types:
            self.visit(t, depth + 2)
        self.emit("} v;", depth + 1)
        for field in sum.attributes:
            # rudimentary attribute handling: only builtin types supported
            attr_type = str(field.type)
            assert attr_type in asdl.builtin_types, attr_type
            self.emit("%s %s;" % (attr_type, field.name), depth + 1)
        self.emit("};", depth)
        self.emit("", depth)

    def visitConstructor(self, cons, depth):
        # Constructors without fields contribute no member to the union.
        if not cons.fields:
            return
        self.emit("struct {", depth)
        for f in cons.fields:
            self.visit(f, depth + 1)
        self.emit("} %s;" % cons.name, depth)
        self.emit("", depth)

    def visitField(self, field, depth):
        # XXX need to lookup field.type, because it might be something
        # like a builtin...
        name = field.name
        if field.seq:
            # cmpop sequences use asdl_int_seq; all others use asdl_seq.
            if field.type.value in ('cmpop',):
                self.emit("asdl_int_seq *%s;" % name, depth)
            else:
                self.emit("asdl_seq *%s;" % name, depth)
        else:
            self.emit("%s %s;" % (get_c_type(field.type), name), depth)

    def visitProduct(self, product, name, depth):
        self.emit("struct _%s {" % name, depth)
        for f in product.fields:
            self.visit(f, depth + 1)
        self.emit("};", depth)
        self.emit("", depth)

207

Jeremy Hylton's avatar
Jeremy Hylton committed
208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242
class PrototypeVisitor(EmitVisitor):
    """Generate function prototypes for the .h file"""

    def visitModule(self, mod):
        for dfn in mod.dfns:
            self.visit(dfn)

    def visitType(self, type):
        self.visit(type.value, type.name)

    def visitSum(self, sum, name):
        # XXX simple sums get no constructor functions.
        if not is_simple(sum):
            for t in sum.types:
                self.visit(t, name, sum.attributes)

    def get_args(self, fields):
        """Return list of C argument info, one for each field.

        Argument info is 3-tuple of a C type, variable name, and flag
        that is true if type can be NULL.
        """
        args = []
        unnamed = {}
        for f in fields:
            if f.name is None:
                # An unnamed field is named after its type; further unnamed
                # fields of the same type become name1, name2, ...
                name = f.type
                c = unnamed[name] = unnamed.get(name, 0) + 1
                if c > 1:
                    name = "name%d" % (c - 1)
            else:
                name = f.name
            # XXX should extend get_c_type() to handle this
            if f.seq:
                if f.type.value in ('cmpop',):
                    ctype = "asdl_int_seq *"
                else:
                    ctype = "asdl_seq *"
            else:
                ctype = get_c_type(f.type)
            args.append((ctype, name, f.opt or f.seq))
        return args

    def visitConstructor(self, cons, type, attrs):
        args = self.get_args(cons.fields)
        attrs = self.get_args(attrs)
        ctype = get_c_type(type)
        self.emit_function(cons.name, ctype, args, attrs)

    def emit_function(self, name, ctype, args, attrs, union=True):
        """Emit a convenience macro and the _Py_<name> prototype.

        The macro forwards a0..aN to _Py_<name>, one per argument plus the
        trailing PyArena.  union is accepted for signature compatibility
        with FunctionVisitor.emit_function and is unused here.
        """
        args = args + attrs
        argstr = ", ".join(["%s %s" % (atype, aname)
                            for atype, aname, opt in args]
                           + ["PyArena *arena"])
        margs = ", ".join(["a%d" % i for i in range(len(args) + 1)])
        self.emit("#define %s(%s) _Py_%s(%s)" % (name, margs, name, margs), 0,
                  reflow=False)
        # depth 0 (the original passed False here, which only worked
        # because False == 0).
        self.emit("%s _Py_%s(%s);" % (ctype, name, argstr), 0)

    def visitProduct(self, prod, name):
        self.emit_function(name, get_c_type(name),
                           self.get_args(prod.fields), [], union=False)
Jeremy Hylton's avatar
Jeremy Hylton committed
276

277

Jeremy Hylton's avatar
Jeremy Hylton committed
278 279 280
class FunctionVisitor(PrototypeVisitor):
    """Visitor to generate constructor functions for AST."""

    def emit_function(self, name, ctype, args, attrs, union=True):
        """Emit the C definition of the constructor function for name.

        args and attrs are lists of (C type, name, optional flag) tuples as
        built by get_args().  The generated function sets ValueError and
        returns NULL when a required argument (other than a bool or int) is
        NULL, allocates the node from arena, and fills it in with
        emit_body_union() when union is true or emit_body_struct() when it
        is false (products).
        """
        argstr = ", ".join(["%s %s" % (atype, aname)
                            for atype, aname, opt in args + attrs]
                           + ["PyArena *arena"])
        self.emit("%s" % ctype, 0)
        self.emit("%s(%s)" % (name, argstr), 0)
        self.emit("{", 0)
        self.emit("%s p;" % ctype, 1)
        for argtype, argname, opt in args:
            # XXX hack alert: false is allowed for a bool (and 0 for an
            # int), so those arguments are never NULL-checked.
            if not opt and argtype not in ("bool", "int"):
                self.emit("if (!%s) {" % argname, 1)
                self.emit("PyErr_SetString(PyExc_ValueError,", 2)
                msg = "field %s is required for %s" % (argname, name)
                self.emit('                "%s");' % msg, 2, reflow=False)
                self.emit('return NULL;', 2)
                self.emit('}', 1)
        self.emit("p = (%s)PyArena_Malloc(arena, sizeof(*p));" % ctype, 1)
        self.emit("if (!p)", 1)
        self.emit("return NULL;", 2)
        if union:
            self.emit_body_union(name, args, attrs)
        else:
            self.emit_body_struct(name, args, attrs)
        self.emit("return p;", 1)
        self.emit("}", 0)
        self.emit("", 0)

    def emit_body_union(self, name, args, attrs):
        """Emit assignments for a sum node: the kind tag, the constructor's
        members inside p->v, then the attributes."""
        self.emit("p->kind = %s_kind;" % name, 1)
        for argtype, argname, opt in args:
            self.emit("p->v.%s.%s = %s;" % (name, argname, argname), 1)
        for argtype, argname, opt in attrs:
            self.emit("p->%s = %s;" % (argname, argname), 1)

    def emit_body_struct(self, name, args, attrs):
        """Emit field assignments for a product node (no attributes)."""
        for argtype, argname, opt in args:
            self.emit("p->%s = %s;" % (argname, argname), 1)
        assert not attrs

332

Jeremy Hylton's avatar
Jeremy Hylton committed
333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353
class PickleVisitor(EmitVisitor):
    """Base visitor whose type, constructor and field hooks emit nothing.

    visitModule visits every definition, and visitType visits the type's
    value together with its name.  Subclasses override only the hooks
    they need.
    """

    def visitModule(self, mod):
        for definition in mod.dfns:
            self.visit(definition)

    def visitType(self, type):
        self.visit(type.value, type.name)

    def visitSum(self, sum, name):
        """Default hook for sum types: emit nothing."""

    def visitProduct(self, sum, name):
        """Default hook for product types: emit nothing."""

    def visitConstructor(self, cons, name):
        """Default hook for constructors: emit nothing."""

    def visitField(self, sum):
        """Default hook for fields: emit nothing."""

354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377

class Obj2ModPrototypeVisitor(PickleVisitor):
    """Emit forward declarations of the obj2ast_* converter functions."""

    def visitProduct(self, prod, name):
        declaration = ("static int obj2ast_%s(PyObject* obj, %s* out, "
                       "PyArena* arena);" % (name, get_c_type(name)))
        self.emit(declaration, 0)

    # Sums and products get the same kind of declaration.
    visitSum = visitProduct


class Obj2ModVisitor(PickleVisitor):
    """Emit the obj2ast_* functions converting Python objects to AST values.

    Every generated function has the C signature
    int obj2ast_<name>(PyObject* obj, <ctype>* out, PyArena* arena)
    and returns 0 on success and 1 on failure.
    """

    def funcHeader(self, name):
        """Emit signature, opening brace and the tmp declaration."""
        ctype = get_c_type(name)
        self.emit("int", 0)
        self.emit("obj2ast_%s(PyObject* obj, %s* out, PyArena* arena)" % (name, ctype), 0)
        self.emit("{", 0)
        self.emit("PyObject* tmp = NULL;", 1)
        self.emit("", 0)

    def sumTrailer(self, name):
        """Emit the TypeError raised when obj matched no alternative, the
        shared failed: label, and the closing brace."""
        self.emit("", 0)
        self.emit("tmp = PyObject_Repr(obj);", 1)
        # there's really nothing more we can do if this fails ...
        self.emit("if (tmp == NULL) goto failed;", 1)
        error = "expected some sort of %s, but got %%.400s" % name
        fmt = "PyErr_Format(PyExc_TypeError, \"%s\", PyString_AS_STRING(tmp));"
        self.emit(fmt % error, 1, reflow=False)
        self.emit("failed:", 0)
        self.emit("Py_XDECREF(tmp);", 1)
        self.emit("return 1;", 1)
        self.emit("}", 0)
        self.emit("", 0)

    def simpleSum(self, sum, name):
        self.funcHeader(name)
        for t in sum.types:
            self.emit("if (PyObject_IsInstance(obj, (PyObject*)%s_type)) {" % t.name, 1)
            self.emit("*out = %s;" % t.name, 2)
            self.emit("return 0;", 2)
            self.emit("}", 1)
        self.sumTrailer(name)

    def buildArgs(self, fields):
        """Join argument names into a C argument list ending in arena."""
        return ", ".join(fields + ["arena"])

    def complexSum(self, sum, name):
        self.funcHeader(name)
        for a in sum.attributes:
            self.visitAttributeDeclaration(a, name, sum=sum)
        self.emit("", 0)
        # XXX: should we only do this for 'expr'?
        self.emit("if (obj == Py_None) {", 1)
        self.emit("*out = NULL;", 2)
        self.emit("return 0;", 2)
        self.emit("}", 1)
        for a in sum.attributes:
            self.visitField(a, name, sum=sum, depth=1)
        for t in sum.types:
            self.emit("if (PyObject_IsInstance(obj, (PyObject*)%s_type)) {" % t.name, 1)
            for f in t.fields:
                self.visitFieldDeclaration(f, t.name, sum=sum, depth=2)
            self.emit("", 0)
            for f in t.fields:
                self.visitField(f, t.name, sum=sum, depth=2)
            args = [f.name.value for f in t.fields] + [a.name.value for a in sum.attributes]
            self.emit("*out = %s(%s);" % (t.name, self.buildArgs(args)), 2)
            self.emit("if (*out == NULL) goto failed;", 2)
            self.emit("return 0;", 2)
            self.emit("}", 1)
        self.sumTrailer(name)

    def visitAttributeDeclaration(self, a, name, sum=None):
        # sum previously defaulted to the builtin sum(); it is unused here.
        ctype = get_c_type(a.type)
        self.emit("%s %s;" % (ctype, a.name), 1)

    def visitSum(self, sum, name):
        if is_simple(sum):
            self.simpleSum(sum, name)
        else:
            self.complexSum(sum, name)

    def visitProduct(self, prod, name):
        ctype = get_c_type(name)
        self.emit("int", 0)
        self.emit("obj2ast_%s(PyObject* obj, %s* out, PyArena* arena)" % (name, ctype), 0)
        self.emit("{", 0)
        self.emit("PyObject* tmp = NULL;", 1)
        for f in prod.fields:
            self.visitFieldDeclaration(f, name, prod=prod, depth=1)
        self.emit("", 0)
        for f in prod.fields:
            self.visitField(f, name, prod=prod, depth=1)
        args = [f.name.value for f in prod.fields]
        self.emit("*out = %s(%s);" % (name, self.buildArgs(args)), 1)
        # The constructor returns NULL with an exception set on failure;
        # report that as failure, as complexSum does, instead of
        # returning 0 with a pending exception.
        self.emit("if (*out == NULL) goto failed;", 1)
        self.emit("return 0;", 1)
        self.emit("failed:", 0)
        self.emit("Py_XDECREF(tmp);", 1)
        self.emit("return 1;", 1)
        self.emit("}", 0)
        self.emit("", 0)

    def visitFieldDeclaration(self, field, name, sum=None, prod=None, depth=0):
        if field.seq:
            if self.isSimpleType(field):
                self.emit("asdl_int_seq* %s;" % field.name, depth)
            else:
                self.emit("asdl_seq* %s;" % field.name, depth)
        else:
            ctype = get_c_type(field.type)
            self.emit("%s %s;" % (ctype, field.name), depth)

    def isSimpleSum(self, field):
        # XXX can the members of this list be determined automatically?
        return field.type.value in ('expr_context', 'boolop', 'operator',
                                    'unaryop', 'cmpop')

    def isNumeric(self, field):
        return get_c_type(field.type) in ("int", "bool")

    def isSimpleType(self, field):
        return self.isSimpleSum(field) or self.isNumeric(field)

    def visitField(self, field, name, sum=None, prod=None, depth=0):
        """Emit code that reads attribute field.name of obj and converts it.

        A missing required field sets TypeError and returns 1; a missing
        optional field is set to 0 (numeric) or NULL.
        """
        ctype = get_c_type(field.type)
        self.emit("if (PyObject_HasAttrString(obj, \"%s\")) {" % field.name, depth)
        self.emit("int res;", depth+1)
        if field.seq:
            self.emit("Py_ssize_t len;", depth+1)
            self.emit("Py_ssize_t i;", depth+1)
        self.emit("tmp = PyObject_GetAttrString(obj, \"%s\");" % field.name, depth+1)
        self.emit("if (tmp == NULL) goto failed;", depth+1)
        if field.seq:
            self.emit("if (!PyList_Check(tmp)) {", depth+1)
            self.emit("PyErr_Format(PyExc_TypeError, \"%s field \\\"%s\\\" must "
                      "be a list, not a %%.200s\", tmp->ob_type->tp_name);" %
                      (name, field.name),
                      depth+2, reflow=False)
            self.emit("goto failed;", depth+2)
            self.emit("}", depth+1)
            self.emit("len = PyList_GET_SIZE(tmp);", depth+1)
            if self.isSimpleType(field):
                self.emit("%s = asdl_int_seq_new(len, arena);" % field.name, depth+1)
            else:
                self.emit("%s = asdl_seq_new(len, arena);" % field.name, depth+1)
            self.emit("if (%s == NULL) goto failed;" % field.name, depth+1)
            self.emit("for (i = 0; i < len; i++) {", depth+1)
            self.emit("%s value;" % ctype, depth+2)
            self.emit("res = obj2ast_%s(PyList_GET_ITEM(tmp, i), &value, arena);" %
                      field.type, depth+2, reflow=False)
            self.emit("if (res != 0) goto failed;", depth+2)
            self.emit("asdl_seq_SET(%s, i, value);" % field.name, depth+2)
            self.emit("}", depth+1)
        else:
            self.emit("res = obj2ast_%s(tmp, &%s, arena);" %
                      (field.type, field.name), depth+1)
            self.emit("if (res != 0) goto failed;", depth+1)

        self.emit("Py_XDECREF(tmp);", depth+1)
        self.emit("tmp = NULL;", depth+1)
        self.emit("} else {", depth)
        if not field.opt:
            message = "required field \\\"%s\\\" missing from %s" % (field.name, name)
            fmt = "PyErr_SetString(PyExc_TypeError, \"%s\");"
            self.emit(fmt % message, depth+1, reflow=False)
            self.emit("return 1;", depth+1)
        else:
            if self.isNumeric(field):
                self.emit("%s = 0;" % field.name, depth+1)
            elif not self.isSimpleType(field):
                self.emit("%s = NULL;" % field.name, depth+1)
            else:
                raise TypeError("could not determine the default value for %s" % field.name)
        self.emit("}", depth)


Jeremy Hylton's avatar
Jeremy Hylton committed
529 530 531 532
class MarshalPrototypeVisitor(PickleVisitor):
    """Emit forward declarations of the marshal_write_* functions."""

    def prototype(self, sum, name):
        ctype = get_c_type(name)
        self.emit("static int marshal_write_%s(PyObject **, int *, %s);"
                  % (name, ctype), 0)

    # Sums and products get the same kind of declaration.
    visitProduct = visitSum = prototype

538

539
class PyTypesDeclareVisitor(PickleVisitor):
    """Emit static declarations for the AST's Python type objects, their
    field/attribute name arrays, and the ast2obj_* prototypes."""

    def _emit_name_array(self, header, names):
        # Emit a C string-array initializer: the header line, one quoted
        # name per line, then the closing "};".
        self.emit(header, 0)
        for n in names:
            self.emit('"%s",' % n, 1)
        self.emit("};", 0)

    def visitProduct(self, prod, name):
        self.emit("static PyTypeObject *%s_type;" % name, 0)
        self.emit("static PyObject* ast2obj_%s(void*);" % name, 0)
        if prod.fields:
            self._emit_name_array("static char *%s_fields[]={" % name,
                                  [f.name for f in prod.fields])

    def visitSum(self, sum, name):
        self.emit("static PyTypeObject *%s_type;" % name, 0)
        if sum.attributes:
            self._emit_name_array("static char *%s_attributes[] = {" % name,
                                  [a.name for a in sum.attributes])
        ptype = "void*"
        if is_simple(sum):
            # Simple sums are converted by value and get one singleton
            # PyObject per alternative.
            ptype = get_c_type(name)
            tnames = ", *".join([str(t.name) + "_singleton"
                                 for t in sum.types])
            self.emit("static PyObject *%s;" % tnames, 0)
        self.emit("static PyObject* ast2obj_%s(%s);" % (name, ptype), 0)
        for t in sum.types:
            self.visitConstructor(t, name)

    def visitConstructor(self, cons, name):
        self.emit("static PyTypeObject *%s_type;" % cons.name, 0)
        if cons.fields:
            self._emit_name_array("static char *%s_fields[]={" % cons.name,
                                  [t.name for t in cons.fields])
576 577 578 579 580

class PyTypesVisitor(PickleVisitor):

    def visitModule(self, mod):
        self.emit("""
581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597
static int
ast_type_init(PyObject *self, PyObject *args, PyObject *kw)
{
    Py_ssize_t i, numfields = 0;
    int res = -1;
    PyObject *key, *value, *fields;
    fields = PyObject_GetAttrString((PyObject*)Py_TYPE(self), "_fields");
    if (!fields)
        PyErr_Clear();
    if (fields) {
        numfields = PySequence_Size(fields);
        if (numfields == -1)
            goto cleanup;
    }
    res = 0; /* if no error occurs, this stays 0 to the end */
    if (PyTuple_GET_SIZE(args) > 0) {
        if (numfields != PyTuple_GET_SIZE(args)) {
598
            PyErr_Format(PyExc_TypeError, "%.400s constructor takes %s"
599
                         "%zd positional argument%s",
600 601
                         Py_TYPE(self)->tp_name,
                         numfields == 0 ? "" : "either 0 or ",
602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631
                         numfields, numfields == 1 ? "" : "s");
            res = -1;
            goto cleanup;
        }
        for (i = 0; i < PyTuple_GET_SIZE(args); i++) {
            /* cannot be reached when fields is NULL */
            PyObject *name = PySequence_GetItem(fields, i);
            if (!name) {
                res = -1;
                goto cleanup;
            }
            res = PyObject_SetAttr(self, name, PyTuple_GET_ITEM(args, i));
            Py_DECREF(name);
            if (res < 0)
                goto cleanup;
        }
    }
    if (kw) {
        i = 0;  /* needed by PyDict_Next */
        while (PyDict_Next(kw, &i, &key, &value)) {
            res = PyObject_SetAttr(self, key, value);
            if (res < 0)
                goto cleanup;
        }
    }
  cleanup:
    Py_XDECREF(fields);
    return res;
}

632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656
/* Pickling support */
static PyObject *
ast_type_reduce(PyObject *self, PyObject *unused)
{
    PyObject *res;
    PyObject *dict = PyObject_GetAttrString(self, "__dict__");
    if (dict == NULL) {
        if (PyErr_ExceptionMatches(PyExc_AttributeError))
            PyErr_Clear();
        else
            return NULL;
    }
    if (dict) {
        res = Py_BuildValue("O()O", Py_TYPE(self), dict);
        Py_DECREF(dict);
        return res;
    }
    return Py_BuildValue("O()", Py_TYPE(self));
}

static PyMethodDef ast_type_methods[] = {
    {"__reduce__", ast_type_reduce, METH_NOARGS, NULL},
    {NULL}
};

657 658
static PyTypeObject AST_type = {
    PyVarObject_HEAD_INIT(&PyType_Type, 0)
659
    "_ast.AST",
660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684
    sizeof(PyObject),
    0,
    0,                       /* tp_dealloc */
    0,                       /* tp_print */
    0,                       /* tp_getattr */
    0,                       /* tp_setattr */
    0,                       /* tp_compare */
    0,                       /* tp_repr */
    0,                       /* tp_as_number */
    0,                       /* tp_as_sequence */
    0,                       /* tp_as_mapping */
    0,                       /* tp_hash */
    0,                       /* tp_call */
    0,                       /* tp_str */
    PyObject_GenericGetAttr, /* tp_getattro */
    PyObject_GenericSetAttr, /* tp_setattro */
    0,                       /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
    0,                       /* tp_doc */
    0,                       /* tp_traverse */
    0,                       /* tp_clear */
    0,                       /* tp_richcompare */
    0,                       /* tp_weaklistoffset */
    0,                       /* tp_iter */
    0,                       /* tp_iternext */
685
    ast_type_methods,        /* tp_methods */
686 687 688 689 690 691 692 693 694 695 696 697 698 699
    0,                       /* tp_members */
    0,                       /* tp_getset */
    0,                       /* tp_base */
    0,                       /* tp_dict */
    0,                       /* tp_descr_get */
    0,                       /* tp_descr_set */
    0,                       /* tp_dictoffset */
    (initproc)ast_type_init, /* tp_init */
    PyType_GenericAlloc,     /* tp_alloc */
    PyType_GenericNew,       /* tp_new */
    PyObject_Del,            /* tp_free */
};


700 701 702 703
static PyTypeObject* make_type(char *type, PyTypeObject* base, char**fields, int num_fields)
{
    PyObject *fnames, *result;
    int i;
704 705 706
    fnames = PyTuple_New(num_fields);
    if (!fnames) return NULL;
    for (i = 0; i < num_fields; i++) {
707
        PyObject *field = PyString_FromString(fields[i]);
708 709 710 711 712 713
        if (!field) {
            Py_DECREF(fnames);
            return NULL;
        }
        PyTuple_SET_ITEM(fnames, i, field);
    }
Tim Peters's avatar
Tim Peters committed
714
    result = PyObject_CallFunction((PyObject*)&PyType_Type, "s(O){sOss}",
Martin v. Löwis's avatar
Martin v. Löwis committed
715
                    type, base, "_fields", fnames, "__module__", "_ast");
716 717 718 719
    Py_DECREF(fnames);
    return (PyTypeObject*)result;
}

Martin v. Löwis's avatar
Martin v. Löwis committed
720 721
static int add_attributes(PyTypeObject* type, char**attrs, int num_fields)
{
722
    int i, result;
723
    PyObject *s, *l = PyTuple_New(num_fields);
Martin v. Löwis's avatar
Martin v. Löwis committed
724
    if (!l) return 0;
725
    for(i = 0; i < num_fields; i++) {
726
        s = PyString_FromString(attrs[i]);
Martin v. Löwis's avatar
Martin v. Löwis committed
727 728 729 730
        if (!s) {
            Py_DECREF(l);
            return 0;
        }
731
        PyTuple_SET_ITEM(l, i, s);
Martin v. Löwis's avatar
Martin v. Löwis committed
732
    }
733 734 735
    result = PyObject_SetAttrString((PyObject*)type, "_attributes", l) >= 0;
    Py_DECREF(l);
    return result;
Martin v. Löwis's avatar
Martin v. Löwis committed
736 737
}

738 739
/* Conversion AST -> Python */

740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770
static PyObject* ast2obj_list(asdl_seq *seq, PyObject* (*func)(void*))
{
    int i, n = asdl_seq_LEN(seq);
    PyObject *result = PyList_New(n);
    PyObject *value;
    if (!result)
        return NULL;
    for (i = 0; i < n; i++) {
        value = func(asdl_seq_GET(seq, i));
        if (!value) {
            Py_DECREF(result);
            return NULL;
        }
        PyList_SET_ITEM(result, i, value);
    }
    return result;
}

static PyObject* ast2obj_object(void *o)
{
    if (!o)
        o = Py_None;
    Py_INCREF((PyObject*)o);
    return (PyObject*)o;
}
#define ast2obj_identifier ast2obj_object
#define ast2obj_string ast2obj_object
static PyObject* ast2obj_bool(bool b)
{
    return PyBool_FromLong(b);
}
Martin v. Löwis's avatar
Martin v. Löwis committed
771

772
static PyObject* ast2obj_int(long b)
Martin v. Löwis's avatar
Martin v. Löwis committed
773 774 775
{
    return PyInt_FromLong(b);
}
776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799

/* Conversion Python -> AST */

static int obj2ast_object(PyObject* obj, PyObject** out, PyArena* arena)
{
    if (obj == Py_None)
        obj = NULL;
    if (obj)
        PyArena_AddPyObject(arena, obj);
    Py_XINCREF(obj);
    *out = obj;
    return 0;
}

#define obj2ast_identifier obj2ast_object
#define obj2ast_string obj2ast_object

static int obj2ast_int(PyObject* obj, int* out, PyArena* arena)
{
    int i;
    if (!PyInt_Check(obj) && !PyLong_Check(obj)) {
        PyObject *s = PyObject_Repr(obj);
        if (s == NULL) return 1;
        PyErr_Format(PyExc_ValueError, "invalid integer value: %.400s",
800
                     PyString_AS_STRING(s));
801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817
        Py_DECREF(s);
        return 1;
    }

    i = (int)PyLong_AsLong(obj);
    if (i == -1 && PyErr_Occurred())
        return 1;
    *out = i;
    return 0;
}

static int obj2ast_bool(PyObject* obj, bool* out, PyArena* arena)
{
    if (!PyBool_Check(obj)) {
        PyObject *s = PyObject_Repr(obj);
        if (s == NULL) return 1;
        PyErr_Format(PyExc_ValueError, "invalid boolean value: %.400s",
818
                     PyString_AS_STRING(s));
819 820 821 822 823 824 825 826
        Py_DECREF(s);
        return 1;
    }

    *out = (obj == Py_True);
    return 0;
}

Benjamin Peterson's avatar
Benjamin Peterson committed
827
static int add_ast_fields(void)
828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843
{
    PyObject *empty_tuple, *d;
    if (PyType_Ready(&AST_type) < 0)
        return -1;
    d = AST_type.tp_dict;
    empty_tuple = PyTuple_New(0);
    if (!empty_tuple ||
        PyDict_SetItemString(d, "_fields", empty_tuple) < 0 ||
        PyDict_SetItemString(d, "_attributes", empty_tuple) < 0) {
        Py_XDECREF(empty_tuple);
        return -1;
    }
    Py_DECREF(empty_tuple);
    return 0;
}

844 845 846 847
""", 0, reflow=False)

        self.emit("static int init_types(void)",0)
        self.emit("{", 0)
848
        self.emit("static int initialized;", 1)
849
        self.emit("if (initialized) return 1;", 1)
850
        self.emit("if (add_ast_fields() < 0) return 0;", 1)
851 852
        for dfn in mod.dfns:
            self.visit(dfn)
Martin v. Löwis's avatar
Martin v. Löwis committed
853 854
        self.emit("initialized = 1;", 1)
        self.emit("return 1;", 1);
855 856 857
        self.emit("}", 0)

    def visitProduct(self, prod, name):
        """Emit creation of the Python type object for product type *name*.

        The type derives from AST_type and lists the product's fields via
        the ``<name>_fields`` array (NULL when there are none).  The emitted
        check makes the enclosing generated init_types() return 0 when
        make_type() fails.
        """
        field_count = len(prod.fields)
        fields_array = "NULL"
        if prod.fields:
            fields_array = name.value + "_fields"
        self.emit('%s_type = make_type("%s", &AST_type, %s, %d);'
                  % (name, name, fields_array, field_count), 1)
        self.emit("if (!%s_type) return 0;" % name, 1)

    def visitSum(self, sum, name):
        """Emit creation of the abstract type for sum *name* and its constructors.

        The sum's type derives from AST_type and has no fields of its own;
        its shared attributes (if any) are attached with add_attributes().
        Each constructor is then handed to visitConstructor(), together with
        whether the whole sum is simple.
        """
        emit = self.emit
        emit('%s_type = make_type("%s", &AST_type, NULL, 0);' % (name, name), 1)
        emit("if (!%s_type) return 0;" % name, 1)
        if not sum.attributes:
            emit("if (!add_attributes(%s_type, NULL, 0)) return 0;" % name, 1)
        else:
            emit("if (!add_attributes(%s_type, %s_attributes, %d)) return 0;"
                 % (name, name, len(sum.attributes)), 1)
        all_simple = is_simple(sum)
        for ctor in sum.types:
            self.visitConstructor(ctor, name, all_simple)

    def visitConstructor(self, cons, name, simple):
        """Emit creation of the type object for constructor *cons* of sum *name*.

        The constructor's type subclasses ``<name>_type``.  When *simple* is
        true, a singleton instance ``<cons>_singleton`` is also created for
        it.
        """
        ctor = cons.name
        if not cons.fields:
            fields_array = "NULL"
        else:
            fields_array = ctor.value + "_fields"
        self.emit('%s_type = make_type("%s", %s_type, %s, %d);'
                  % (ctor, ctor, name, fields_array, len(cons.fields)), 1)
        self.emit("if (!%s_type) return 0;" % ctor, 1)
        if not simple:
            return
        self.emit("%s_singleton = PyType_GenericNew(%s_type, NULL, NULL);"
                  % (ctor, ctor), 1)
        self.emit("if (!%s_singleton) return 0;" % ctor, 1)


def parse_version(mod):
    """Return the revision string embedded in the ASDL module's version.

    Drops the first 12 and the last 3 characters of ``mod.version.value``.
    This presumably matches a quoted, expanded keyword such as
    '"$Revision: 1234 $"' (prefix '"$Revision: ', suffix ' $"') -- confirm
    against the ASDL parser.
    """
    raw = mod.version.value
    return raw[12:len(raw) - 3]


class ASTModuleVisitor(PickleVisitor):
    """Emit init_ast(), the init function of the "_ast" extension module.

    The generated function calls init_types(), creates the module, and
    publishes AST_type, the PyCF_ONLY_AST constant and __version__.  It
    then stores the type object of every ASDL type and constructor in the
    module dict.  Every failure path in the generated code simply returns.
    """

    def visitModule(self, mod):
        emit = self.emit
        for header in ("PyMODINIT_FUNC", "init_ast(void)", "{"):
            emit(header, 0)
        for stmt in ("PyObject *m, *d;",
                     "if (!init_types()) return;",
                     'm = Py_InitModule3("_ast", NULL, NULL);',
                     "if (!m) return;",
                     "d = PyModule_GetDict(m);",
                     'if (PyDict_SetItemString(d, "AST", (PyObject*)&AST_type) < 0) return;'):
            emit(stmt, 1)
        emit('if (PyModule_AddIntConstant(m, "PyCF_ONLY_AST", PyCF_ONLY_AST) < 0)', 1)
        emit("return;", 2)
        # Value of version: "$Revision$"; parse_version() slices the
        # revision out of that keyword.
        version = parse_version(mod)
        emit('if (PyModule_AddStringConstant(m, "__version__", "%s") < 0)' % version, 1)
        emit("return;", 2)
        for dfn in mod.dfns:
            self.visit(dfn)
        emit("}", 0)

    def visitProduct(self, prod, name):
        self.addObj(name)

    def visitSum(self, sum, name):
        # Register the abstract sum type first, then each constructor.
        self.addObj(name)
        for ctor in sum.types:
            self.visitConstructor(ctor, name)

    def visitConstructor(self, cons, name):
        self.addObj(cons.name)

    def addObj(self, name):
        """Emit code storing ``<name>_type`` in the module dict under "<name>"."""
        self.emit('if (PyDict_SetItemString(d, "%(n)s", (PyObject*)%(n)s_type) < 0) return;'
                  % {'n': name}, 1)


# Element types whose sequences find_sequence() ignores when
# doing_specialization is true.
_SPECIALIZED_SEQUENCES = ('stmt', 'expr')

def find_sequence(fields, doing_specialization):
    """Return True if any of *fields* is a sequence field.

    When *doing_specialization* is true, sequences whose element type is in
    _SPECIALIZED_SEQUENCES do not count.
    """
    for fld in fields:
        if not fld.seq:
            continue
        if doing_specialization and str(fld.type) in _SPECIALIZED_SEQUENCES:
            continue
        return True
    return False

def has_sequence(types, doing_specialization):
    """Return True if find_sequence() holds for the fields of any of *types*."""
    for constructor in types:
        if find_sequence(constructor.fields, doing_specialization):
            return True
    return False


class StaticVisitor(PickleVisitor):
    """Visitor that ignores the tree and writes out a fixed chunk of code.

    Subclasses override CODE; it is emitted verbatim at depth 0 with
    reflowing disabled.
    """
    CODE = '''Very simple, always emit this static code.  Overide CODE'''

    def visit(self, object):
        # The visited object is not inspected; output depends only on CODE.
        text = self.CODE
        self.emit(text, 0, reflow=False)


class ObjVisitor(PickleVisitor):
    """Emit the ast2obj_* functions that convert C AST nodes to Python objects.

    For each sum and product in the ASDL module a C function
    ``ast2obj_<name>`` is generated.  It instantiates the corresponding
    Python type object (the constructor's type, for sums) and sets one
    attribute per ASDL field, plus the sum's shared attributes.
    """

    def func_begin(self, name):
        """Emit the header and prologue of ``ast2obj_<name>(void* _o)``.

        The generated code casts ``_o`` to the C type of *name*, declares
        the ``result``/``value`` locals used by the field emitters, and
        returns a new reference to None when the node pointer is NULL.
        """
        ctype = get_c_type(name)
        self.emit("PyObject*", 0)
        self.emit("ast2obj_%s(void* _o)" % (name), 0)
        self.emit("{", 0)
        self.emit("%s o = (%s)_o;" % (ctype, ctype), 1)
        self.emit("PyObject *result = NULL, *value = NULL;", 1)
        self.emit('if (!o) {', 1)
        self.emit("Py_INCREF(Py_None);", 2)
        self.emit('return Py_None;', 2)
        self.emit("}", 1)
        self.emit('', 0)

    def func_end(self):
        """Emit the epilogue shared by every generated ast2obj_* function.

        ``result`` is returned on success; the ``failed`` label drops any
        partially built ``value``/``result`` and returns NULL.
        """
        self.emit("return result;", 1)
        self.emit("failed:", 0)
        self.emit("Py_XDECREF(value);", 1)
        self.emit("Py_XDECREF(result);", 1)
        self.emit("return NULL;", 1)
        self.emit("}", 0)
        self.emit("", 0)

    def visitSum(self, sum, name):
        """Emit ast2obj_<name> for a sum type.

        Simple sums are delegated to simpleSum().  Otherwise the generated
        function switches on ``o->kind``, builds the matching constructor's
        object, and then copies the sum's shared attributes onto it.
        """
        if is_simple(sum):
            self.simpleSum(sum, name)
            return
        self.func_begin(name)
        self.emit("switch (o->kind) {", 1)
        # visitConstructor's (unused) enum argument is 1-based.
        for i, t in enumerate(sum.types):
            self.visitConstructor(t, i + 1, name)
        self.emit("}", 1)
        for a in sum.attributes:
            self.emit("value = ast2obj_%s(o->%s);" % (a.type, a.name), 1)
            self.emit("if (!value) goto failed;", 1)
            self.emit('if (PyObject_SetAttrString(result, "%s", value) < 0)' % a.name, 1)
            self.emit('goto failed;', 2)
            self.emit('Py_DECREF(value);', 1)
        self.func_end()

    def simpleSum(self, sum, name):
        """Emit ast2obj_<name> for a simple (enum-like) sum.

        Each enumerator returns a new reference to its ``<Cons>_singleton``.
        An unexpected value sets SystemError and returns NULL in the
        generated code.
        """
        self.emit("PyObject* ast2obj_%s(%s_ty o)" % (name, name), 0)
        self.emit("{", 0)
        self.emit("switch(o) {", 1)
        for t in sum.types:
            self.emit("case %s:" % t.name, 2)
            self.emit("Py_INCREF(%s_singleton);" % t.name, 3)
            self.emit("return %s_singleton;" % t.name, 3)
        # No "% name" here: the label has no conversion specifier, so
        # formatting it raises "not all arguments converted" unless name
        # happens to be a mapping-like object.
        self.emit("default:", 2)
        self.emit('/* should never happen, but just in case ... */', 3)
        code = "PyErr_Format(PyExc_SystemError, \"unknown %s found\");" % name
        self.emit(code, 3, reflow=False)
        self.emit("return NULL;", 3)
        self.emit("}", 1)
        self.emit("}", 0)

    def visitProduct(self, prod, name):
        """Emit ast2obj_<name> for a product type, setting one attribute per field."""
        self.func_begin(name)
        self.emit("result = PyType_GenericNew(%s_type, NULL, NULL);" % name, 1)
        self.emit("if (!result) return NULL;", 1)
        for field in prod.fields:
            self.visitField(field, name, 1, True)
        self.func_end()

    def visitConstructor(self, cons, enum, name):
        """Emit the ``case <cons>_kind:`` arm of a sum's ast2obj switch.

        *enum* is accepted for interface compatibility but not used.
        """
        self.emit("case %s_kind:" % cons.name, 1)
        self.emit("result = PyType_GenericNew(%s_type, NULL, NULL);" % cons.name, 2)
        self.emit("if (!result) goto failed;", 2)
        for f in cons.fields:
            self.visitField(f, cons.name, 2, False)
        self.emit("break;", 2)

    def visitField(self, field, name, depth, product):
        """Emit code converting one field and storing it as an attribute of ``result``.

        *product* selects how the field is addressed in C: ``o-><field>``
        for products, ``o->v.<name>.<field>`` for sum constructors.
        """
        def emit(s, d):
            self.emit(s, depth + d)
        if product:
            value = "o->%s" % field.name
        else:
            value = "o->v.%s.%s" % (name, field.name)
        self.set(field, value, depth)
        emit("if (!value) goto failed;", 0)
        emit('if (PyObject_SetAttrString(result, "%s", value) == -1)' % field.name, 0)
        emit("goto failed;", 1)
        emit("Py_DECREF(value);", 0)

    # NOTE(review): nothing in this chunk calls emitSeq.  Its self.set(...)
    # call passes four arguments to set(), which takes three, so invoking it
    # raises TypeError; the C it emits also uses an undeclared ``value1``.
    # Kept unchanged pending removal.
    def emitSeq(self, field, value, depth, emit):
        emit("seq = %s;" % value, 0)
        emit("n = asdl_seq_LEN(seq);", 0)
        emit("value = PyList_New(n);", 0)
        emit("if (!value) goto failed;", 0)
        emit("for (i = 0; i < n; i++) {", 0)
        self.set("value", field, "asdl_seq_GET(seq, i)", depth + 1)
        emit("if (!value1) goto failed;", 1)
        emit("PyList_SET_ITEM(value, i, value1);", 1)
        emit("value1 = NULL;", 1)
        emit("}", 0)

    def set(self, field, value, depth):
        """Emit ``value = ...;`` converting the C expression *value* for *field*.

        Sequence fields use ast2obj_list() with the element converter.  The
        exception is cmpop sequences, whose void* elements are cast back to
        the enum inline.  Scalar fields call ast2obj_<type> directly.
        """
        if field.seq:
            # XXX should really check for is_simple, but that requires a symbol table
            if field.type.value == "cmpop":
                # While the sequence elements are stored as void*,
                # ast2obj_cmpop expects an enum
                self.emit("{", depth)
                self.emit("int i, n = asdl_seq_LEN(%s);" % value, depth+1)
                self.emit("value = PyList_New(n);", depth+1)
                self.emit("if (!value) goto failed;", depth+1)
                self.emit("for(i = 0; i < n; i++)", depth+1)
                # This cannot fail, so no need for error handling.
                # NOTE(review): simpleSum's default branch does return NULL
                # for an out-of-range enum value; that NULL would be stored
                # in the list unchecked.
                self.emit("PyList_SET_ITEM(value, i, ast2obj_cmpop((cmpop_ty)asdl_seq_GET(%s, i)));" % value,
                          depth+2, reflow=False)
                self.emit("}", depth)
            else:
                self.emit("value = ast2obj_list(%s, ast2obj_%s);" % (value, field.type), depth)
        else:
            self.emit("value = ast2obj_%s(%s);" % (field.type, value), depth, reflow=False)


class PartingShots(StaticVisitor):
    """Emit the public C entry points PyAST_mod2obj, PyAST_obj2mod, PyAST_Check.

    Each entry point first calls init_types(), which returns 0 on failure,
    and bails out in that case instead of using type objects that were
    never created.  PyAST_obj2mod reads the Module/Expression/Interactive
    type objects only after init_types() has run.  It also treats a -1
    (error) result from PyObject_IsInstance as a failure rather than as
    "is an instance".
    """

    CODE = """
PyObject* PyAST_mod2obj(mod_ty t)
{
    if (!init_types())
        return NULL;
    return ast2obj_mod(t);
}

/* mode is 0 for "exec", 1 for "eval" and 2 for "single" input */
mod_ty PyAST_obj2mod(PyObject* ast, PyArena* arena, int mode)
{
    mod_ty res;
    PyObject *req_type[3];
    char *req_name[] = {"Module", "Expression", "Interactive"};
    int isinstance;
    assert(0 <= mode && mode <= 2);

    if (!init_types())
        return NULL;

    /* Read the type objects only after init_types() has created them. */
    req_type[0] = (PyObject*)Module_type;
    req_type[1] = (PyObject*)Expression_type;
    req_type[2] = (PyObject*)Interactive_type;

    isinstance = PyObject_IsInstance(ast, req_type[mode]);
    if (isinstance == -1)
        return NULL;
    if (!isinstance) {
        PyErr_Format(PyExc_TypeError, "expected %s node, got %.400s",
                     req_name[mode], Py_TYPE(ast)->tp_name);
        return NULL;
    }
    if (obj2ast_mod(ast, &res, arena) != 0)
        return NULL;
    else
        return res;
}

int PyAST_Check(PyObject* obj)
{
    if (!init_types())
        return -1;
    return PyObject_IsInstance(obj, (PyObject*)&AST_type);
}
"""


class ChainOfVisitors:
    """Run several visitors over the same tree, one after another.

    After each visitor finishes, an empty line is emitted through that
    visitor so that consecutive outputs are separated.
    """

    def __init__(self, *visitors):
        # Kept as the tuple of positional arguments, in call order.
        self.visitors = visitors

    def visit(self, object):
        """Visit *object* with every visitor in turn."""
        for current in self.visitors:
            current.visit(object)
            current.emit("", 0)


# Banner written at the top of every generated file; %s is the generator's
# script path.
common_msg = "/* File automatically generated by %s. */\n\n"

# Extra banner for the generated .c file; %s is the value returned by
# parse_version().  Stray blame-annotation lines that had leaked into this
# literal (and therefore into the generated C) have been removed.
c_file_msg = """
/*
   __version__ %s.

   This module must be committed separately after each AST grammar change;
   The __version__ number is set to the revision number of the commit
   containing the grammar change.
*/

"""


def main(srcfile):
    """Generate C code from the ASDL description in *srcfile*.

    Writes ``<module>-ast.h`` into INC_DIR and/or ``<module>-ast.c`` into
    SRC_DIR.  Both are module-level globals set by the command-line driver;
    an empty value skips that output.  Exits with status 1 if the
    description fails asdl.check().
    """
    # Only the last two path components of the script path go into the
    # "automatically generated by" banner.
    components = sys.argv[0].split(os.sep)
    argv0 = os.sep.join(components[-2:])
    auto_gen_msg = common_msg % argv0

    mod = asdl.parse(srcfile)
    if not asdl.check(mod):
        sys.exit(1)

    if INC_DIR:
        p = os.path.join(INC_DIR, "%s-ast.h" % mod.name)
        f = open(p, "wb")
        # try/finally so the file is closed even if a visitor raises.
        try:
            f.write(auto_gen_msg)
            f.write('#include "asdl.h"\n\n')
            c = ChainOfVisitors(TypeDefVisitor(f),
                                StructVisitor(f),
                                PrototypeVisitor(f),
                                )
            c.visit(mod)
            f.write("PyObject* PyAST_mod2obj(mod_ty t);\n")
            f.write("mod_ty PyAST_obj2mod(PyObject* ast, PyArena* arena, int mode);\n")
            f.write("int PyAST_Check(PyObject* obj);\n")
        finally:
            f.close()

    if SRC_DIR:
        p = os.path.join(SRC_DIR, str(mod.name) + "-ast.c")
        f = open(p, "wb")
        try:
            f.write(auto_gen_msg)
            f.write(c_file_msg % parse_version(mod))
            f.write('#include "Python.h"\n')
            f.write('#include "%s-ast.h"\n' % mod.name)
            f.write('\n')
            f.write("static PyTypeObject AST_type;\n")
            v = ChainOfVisitors(
                PyTypesDeclareVisitor(f),
                PyTypesVisitor(f),
                Obj2ModPrototypeVisitor(f),
                FunctionVisitor(f),
                ObjVisitor(f),
                Obj2ModVisitor(f),
                ASTModuleVisitor(f),
                PartingShots(f),
                )
            v.visit(mod)
        finally:
            f.close()


if __name__ == "__main__":
    # sys is already imported at the top of the file.
    import getopt

    # main() reads the output directories from these module globals.
    INC_DIR = ''
    SRC_DIR = ''
    opts, args = getopt.getopt(sys.argv[1:], "h:c:")
    if len(opts) != 1:
        print("Must specify exactly one output file")
        sys.exit(1)
    for flag, path in opts:
        if flag == '-h':
            INC_DIR = path
        elif flag == '-c':
            SRC_DIR = path
    if len(args) != 1:
        print("Must specify single input file")
        sys.exit(1)
    main(args[0])