#!/usr/bin/python

"""
 "sql2diagram-sxd" v. 1.0.5

    reverse engineer / generate a table structure diagram or
    an ER (Entity-Relationship) diagram in OpenOffice.org Draw
    format from SQL CREATE statements including foreign keys

 Copyright (C) 2003, 2004 by Jarno Elonen <elonen@iki.fi>

 This is free software; you can redistribute it and/or modify
 it under the terms of the GNU General Public License as
 published by the Free Software Foundation; either version 2,
 or (at your option) any later version.

 This is distributed in the hope that it will be useful, but
 WITHOUT ANY WARRANTY; without even the implied warranty of
 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 GNU General Public License for more details.

 You should have received a copy of the GNU General Public
 License along with the program; if not, write to the Free Software
 Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
"""

import sys, math, string, zipfile, cStringIO, getopt


def usagetext():

  """Build the help message shown for -h/--help or on bad options.

     Each supported command-line switch is listed on its own line. The
     returned text starts with a newline and ends with a newline
     followed by two spaces."""

  help_lines = [
    "",
    "Usage:  sql2diagram-sxd [options] < SQL",
    "",
    "where possible options are:",
    "  -s, --struct                generate a table structure diagram (default)",
    "  -e, --er                    generate an ER diagram",
    "  -o <file>, --output=<file>  output to given file instead of STDOUT",
    "  -h, --help                  show this help",
    "  -t, --text                  write a textual description of the graph to stdout",
    "  ",
  ]
  return "\n".join(help_lines)

# Parsed CREATE TABLE definitions keyed by table name. The generated
# parser's "goal" rule stores every table here and returns this dict.
tables = {}

# When True, the drawing helpers also print a textual description of
# every element they emit (switched on by the -t/--text option).
explain = False

# ---- begin yapps2 processed stuff ----
%%

# NOTE! To get a runnable Python program, you need
# to process this file with YAPPS2 parser generator by Amit Patel.
# http://theory.stanford.edu/~amitp/Yapps/
#
# The grammar accepts a MySQL-style subset of CREATE TABLE syntax.
# Every table is turned into a dict
#   {"name": ..., "primary": [column names],
#    "cols": {column name: {"type": ..., "primary": False}},
#    "refs": [{"fromcols": [...], "totable": ..., "tocols": [...]}, ...]}
# Foreign keys declared with FOREIGN KEY also carry a "name" entry.
# Index definitions, defaults and most column attributes are
# recognised but discarded.

parser SqlCreate:
    ignore:       "[ \r\t\n]+"
    token END:    "$"
    token NUM:    "[0-9]+"
    token VAR:    '[a-zA-Z0-9\/_]+'
    token SQUOTE: "[`´'\"]"

    # Whole input: CREATE TABLE statements and comments up to the end.
    # Tables are collected into the module-level 'tables' dict.
    rule goal:       ( create             {{ tab = create }}
                                          {{ tables[tab["name"]] = tab }}
                       | comment          {{ pass }}
                     )+ END               {{ return tables }}

    # Attributes after a column type. Inline PRIMARY KEY and REFERENCES
    # clauses are recorded on 'table'; everything else is skipped.
    rule col_attribs<<table, colname>>: (
                          [["NOT"] "NULL"]                                  {{ pass }}
                          [("DEFAULT"
                            str                                             {{ pass }}
                            ) ]                                             {{ pass }}
                          ["auto_increment"]                                {{ pass }}
                          ["PRIMARY" "KEY"                                  {{ table["primary"].append(colname) }}
                          ]                                                 {{ pass }}
                          ["REFERENCES"                                     {{ ref = {"fromcols":[colname]} }}
                            str                                             {{ ref["totable"] = str; ref["tocols"] = [colname] }}
                            [strlist                                        {{ ref["tocols"] = strlist }}
                            ]
                            ["MATCH" ("FULL"|"PARTIAL")]
                            ("ON" ("DELETE"|"UPDATE") referenceoption)*     {{ table["refs"].append(ref) }}
                          ]
                       )

    # One entry of a CREATE TABLE body: a column definition, a table-level
    # PRIMARY KEY, a FOREIGN KEY constraint, or an index (ignored).
    rule single_col<<table>>: (
                      ( str                                               {{ colname = str }}
                        colspec                                           {{ table["cols"][colname] = {"type":colspec, "primary":False } }}
                        col_attribs<<table, colname>>
                      )

                      | ("PRIMARY" "KEY" strlist )                                  {{ table["primary"] = strlist }}

                      | (["CONSTRAINT" str] "FOREIGN" "KEY" optionalname strlist    {{ ref = {"name":optionalname, "fromcols":strlist} }}
                         "REFERENCES" str                                           {{ ref["totable"] = str }}
                         strlist ["MATCH" ("FULL"|"PARTIAL")]                       {{ ref["tocols"] = strlist }}
                         ("ON" ("DELETE"|"UPDATE") referenceoption)* )              {{ table["refs"].append(ref); }}

                      | ("KEY" |"INDEX") optionalname strlist
                      | "UNIQUE" ["INDEX"|"KEY"] optionalname strlist
                      | "FULLTEXT" ["INDEX"] optionalname strlist
                     )

    # CREATE TABLE statement; returns the table dict described above.
    rule create:     "CREATE" "(TEMPORARY)?" "TABLE" "(IF NOT EXISTS)?" str "\("  {{ table = { "name":str, "primary":[], "cols":{}, "refs":[] }; table["primary"] = [] }}
                     single_col<<table>>
                     ( comment | ( "," single_col<<table>> ) )*
                     "\)" [table_option] ";"                                       {{ return table }}

    # Action for ON DELETE / ON UPDATE. One of these words is required;
    # there is no empty alternative.
    rule referenceoption: "RESTRICT" | "CASCADE"
                        | "NO" "ACTION" | "SET" ("NULL"|"DEFAULT")

    # Trailing table option, e.g. TYPE=InnoDB (parsed and discarded).
    rule table_option: (("TYPE" | "AUTO_INCREMENT" | "AVG_ROW_LENGTH" |
                        "CHECKSUM" | "COMMENT" | "MAX_ROWS" |
                        "MIN_ROWS" | "PACK_KEYS" | "PASSWORD" |
                        "DELAY_KEY_WRITE" | "ROW_FORMAT" |
                        ("RAID_TYPE" "=" str "RAID_CHUNKS" "=" str "RAID_CHUNKSIZE") |
                        "INSERT_METHOD" | "DATA DIRECTORY" | "INDEX DIRECTORY" )
                        "=" str )                                                      {{ pass }}
                      | ("UNION" "=" strlist)                                          {{ pass }}

    rule num_options:   ["unsigned"] ["zerofill"]

    # Column type; returns only the base type name (length, precision
    # and modifiers are consumed but dropped).
    rule colspec: ( "tinyint" [numlist1] num_options                    {{ val = "tinyint" }}
                 | "smallint" [numlist1] num_options                    {{ val = "smallint" }}
                 | "mediumint" [numlist1] num_options                   {{ val = "mediumint" }}
                 | "int" [numlist1] num_options                         {{ val = "int" }}
                 | "integer" [numlist1] num_options                     {{ val = "integer" }}
                 | "bigint" [numlist1] num_options                      {{ val = "bigint" }}
                 | "real" [numlist2] num_options                        {{ val = "real" }}
                 | "double" [numlist2] num_options                      {{ val = "double" }}
                 | "float" [numlist2] num_options                       {{ val = "float" }}
                 | "decimal" numlist2 num_options                       {{ val = "decimal" }}
                 | "numeric" numlist2 num_options                       {{ val = "numeric" }}
                 | "char" numlist1 ["binary" | "ascii" | "unicode"]     {{ val = "char" }}
                 | "varchar" numlist1 ["binary"]                        {{ val = "varchar" }}
                 | "date"                                               {{ val = "date" }}
                 | "time"                                               {{ val = "time" }}
                 | "timestamp(\([0-9]*\))?"                             {{ val = "timestamp" }}
                 | "datetime"                                           {{ val = "datetime" }}
                 | "tinyblob"                                           {{ val = "tinyblob" }}
                 | "blob"                                               {{ val = "blob" }}
                 | "mediumblob"                                         {{ val = "mediumblob" }}
                 | "longblob"                                           {{ val = "longblob" }}
                 | "tinytext"                                           {{ val = "tinytext" }}
                 | "text"                                               {{ val = "text" }}
                 | "mediumtext"                                         {{ val = "mediumtext" }}
                 | "longtext"                                           {{ val = "longtext" }}
                 | "enum" strlist                                       {{ val = "enum" }}
                 | "set" strlist                                        {{ val = "set" }}
                 )                                                      {{ return val }}

    # SQL comments: /* ... */ blocks and -- to end of line.
    rule comment:    "\/\*([^\/]|\/[^\*])*\*\/"
                     | "--.*"

    # Identifier, bare or quoted; an empty quoted string yields "".
    rule str:        VAR                              {{ return VAR }}
                     |( SQUOTE                        {{ val = "" }}
                        [VAR                          {{ val = VAR }}
                        ] SQUOTE )                    {{ return val }}

    # Optional identifier; returns "" when absent.
    rule optionalname: (                 {{ return "" }}
                         |str)            {{ return str }}

    # Parenthesised list of identifiers (separating commas are optional).
    rule strlist:    ( "\("                {{ val = [] }}
                       ( str ",?"          {{ val.append( str ) }}
                       )+ "\)")            {{ return val }}

    rule numlist1:   "\(" NUM "\)"         {{ return NUM }}

    rule numlist2:   (
                      "\(" NUM ","         {{ a = NUM }}
                           NUM             {{ b = NUM }}
                      "\)")                {{ return (a,b) }}

%%
# ---- end yapps2 processed stuff ----


def merge_ooo_doc(doc, fields, data):

    """Merge custom fields into a template OpenOffice.org document.

       The function takes the document template (zip) as a string,
       a list of (tag, key) pairs, and a list of mappings.

       For each (tag, key) pair, it replaces the first instance
       of the tag string with the value of the key in the first
       dictionary, the second instance with the value from the
       second dictionary, etc. It returns the resulting document
       as a string.

       For example, a designer could use use placeholders such as
       '$$CODE$$' and '$$DESCRIPTION$$'. The 'fields' might then look
       like '[("$$CODE$$", "CODE"), ("$$DESCRIPTION$$", "DESCRIPTION"), ...]',
       and the data would be [{"CODE":code_str, "DESCRIPTION:description_str"}].

       This function was borrowed from:
       http://www.zopelabs.com/cookbook/1043777422"""

    docf = cStringIO.StringIO(doc)
    zf = zipfile.ZipFile(docf)
    txt = zf.read('content.xml')

    for (field, k) in fields:
        parts = txt.split(field)
        newparts = [parts.pop(0)]
        i = 0
        while parts and i < len(data):
            newparts.append(str(data[i][k]))
            newparts.append(parts.pop(0))
            i = i + 1
        newparts.extend(parts)
        txt = ''.join(newparts)

    out = cStringIO.StringIO()
    zf2 = zipfile.ZipFile(out, 'w')
    for zi in zf.infolist():
        if zi.filename == 'content.xml':
            zf2.writestr(zi, txt)
        else:
            zf2.writestr(zi, zf.read(zi.filename))
    zf2.close()
    zf.close()
    return out.getvalue()

def ooo_text( texts ):

    """Return OOo Draw XML for one paragraph made of styled text spans.

       "texts" must be a list of tuples (text, style), where "style"
       is appended to "T" to form a text style ID defined in the
       template document. The text is XML-escaped, so characters such
       as '<' or '&' cannot corrupt the generated document."""

    from xml.sax.saxutils import escape

    spans = []
    for text, style in texts:
        spans.append('<text:span text:style-name="T%s">%s</text:span>'
                     % (style, escape("%s" % (text,))))
    return '<text:p text:style-name="P1">' + ''.join(spans) + '</text:p>'

def explain_texts( texts ):

    """Return a human-readable summary of (text, style) tuples for -t output.

       Each fragment is rendered as "text (style S):'T'". Fragments are
       separated by ", " so multi-part labels (e.g. column name plus
       type) stay legible."""

    return ", ".join(["text (style %s):'%s'" % (style, text)
                      for text, style in texts])

def text_box( shape_id, x, y, w, h, texts ):

    """Return OOo Draw XML for a rectangular text box.

       The box gets draw:id 'shape_id', position (x, y) and size (w, h),
       all in centimetres. "texts" is a list of (text, style) tuples
       rendered by ooo_text(). When the global 'explain' flag is set, a
       one-line description is printed to stdout as well."""

    if explain:
      print("  BOX  id: #%d coords: %f, %f size: %f, %f %s" %
            (shape_id, x, y, w, h, explain_texts(texts)))

    # Fill in the attributes first and append the label afterwards, so
    # that a '%' inside the label text cannot break the format operation.
    start_tag = ('<draw:rect draw:style-name="grhollowobject" draw:text-style-name="P1" '
                 'draw:id="%d" draw:layer="layout" svg:width="%fcm" svg:height="%fcm" '
                 'svg:x="%fcm" svg:y="%fcm">') % (shape_id, w, h, x, y)
    return start_tag + ooo_text(texts) + "</draw:rect>"

def arrow_filled( from_id, from_x, from_y, to_id, to_x, to_y ):

    """Return OOo Draw XML for a curved connector with a filled arrowhead.

       The connector runs from glue point 1 of shape 'from_id', at
       (from_x, from_y), to glue point 3 of shape 'to_id', at (to_x, to_y).
       Coordinates are in centimetres."""

    if explain:
      print("  ARROW (FILLED)  #%d -> #%d (coords %f, %f -> %f, %f)" %
            (from_id, to_id, from_x, from_y, to_x, to_y))

    attrs = [
        'draw:style-name="grarrowfilled"',
        'draw:text-style-name="P7"',
        'draw:layer="layout"',
        'draw:type="curve"',
        'svg:x1="%fcm"' % from_x,
        'svg:y1="%fcm"' % from_y,
        'svg:x2="%fcm"' % to_x,
        'svg:y2="%fcm"' % to_y,
        'draw:start-shape="%d"' % from_id,
        'draw:start-glue-point="1"',
        'draw:end-shape="%d"' % to_id,
        'draw:end-glue-point="3"',
    ]
    return '<draw:connector ' + ' '.join(attrs) + '/>'

def _straight_connector( style_name, label, from_id, to_id ):

    """Return OOo Draw XML for a straight connector between two shapes.

       'style_name' is a graphic style defined in the template document.
       'label' names the element kind in the line printed when the global
       'explain' flag is set."""

    if explain:
      print("  %s  #%d -> #%d" % (label, from_id, to_id))
    return ('<draw:connector draw:style-name="%s" draw:text-style-name="P7" '
            'draw:layer="layout" draw:type="line" '
            'draw:start-shape="%d" draw:end-shape="%d"/>') % \
           (style_name, from_id, to_id)

def connector_plain( from_id, to_id ):

    """Return OOo Draw XML for a plain line connector."""

    return _straight_connector("grlineonly", "PLAIN CONNECTOR", from_id, to_id)

def connector_line_arrow( from_id, to_id ):

    """Return OOo Draw XML for a line arrow connector."""

    return _straight_connector("grarrowlines", "LINE ARROW CONNECTOR", from_id, to_id)

def connector_filled_arrow( from_id, to_id ):

    """Return OOo Draw XML for a filled arrow connector."""

    return _straight_connector("grarrowfilled", "FILLED ARROW CONNECTOR", from_id, to_id)

def er_relationship( id, x, y, texts ):

    """Return OOo Draw XML for an ER relationship symbol (a diamond).

       The 1.791 cm x 0.77 cm diamond gets draw:id 'id' and is centred
       on (x, y). "texts" is a list of (text, style) tuples used as its
       label. When the global 'explain' flag is set, a description is
       printed to stdout."""

    if explain:
        print("  RELATIONSHIP  id: #%d coords: %f, %f %s" % (id, x, y, explain_texts(texts)))

    # Fill in the geometry before appending the label, so a '%' in the
    # label text cannot be mistaken for a format directive.
    start_tag = ('<draw:polygon draw:id="%d" draw:style-name="grhollowobject" draw:text-style-name="P1" '
                 'draw:layer="layout" svg:width="1.791cm" svg:height="0.77cm" svg:x="%fcm" svg:y="%fcm" '
                 'svg:viewBox="0 0 1791 770" draw:points="0,384 895,0 1790,384 895,769">') % \
                (id, x - 1.791 / 2, y - 0.77 / 2)
    return start_tag + ooo_text(texts) + "</draw:polygon>"

def er_weak_relationship( id, x, y, texts ):

    """Return OOo Draw XML for an ER weak relationship symbol.

       The symbol is a path tracing an outer diamond and an inner one
       (a double diamond) of 1.791 cm x 0.77 cm. It gets draw:id 'id' and
       is centred on (x, y). "texts" is a list of (text, style) tuples
       used as its label. When the global 'explain' flag is set, a
       description is printed to stdout."""

    if explain:
        print("  WEAK RELATIONSHIP  id: #%d coords: %f, %f %s" % (id, x, y, explain_texts(texts)))

    # Fill in the geometry before appending the label, so a '%' in the
    # label text cannot be mistaken for a format directive.
    start_tag = ('<draw:path draw:style-name="grhollowobject" draw:text-style-name="P1" draw:id="%d" '
                 'draw:layer="layout" svg:width="1.791cm" svg:height="0.77cm" svg:x="%fcm" svg:y="%fcm" '
                 'svg:viewBox="0 0 1791 770" svg:d="m0 384c298-127 596-256 895-384 298 128 596 257 895 384-298 '
                 '128-597 256-895 385-299-129-597-257-895-385zm161 1c244-105 489-210 734-315 244 104 488 209 '
                 '733 315-244 104-488 209-733 314-245-104-489-209-734-314z">') % \
                (id, x - 1.791 / 2, y - 0.77 / 2)
    return start_tag + ooo_text(texts) + "</draw:path>"

def er_entity( id, x, y, texts ):

    """Return OOo Draw XML for an ER entity symbol (a rectangle).

       The 1.773 cm x 0.761 cm rectangle gets draw:id 'id' and is centred
       on (x, y). "texts" is a list of (text, style) tuples used as its
       label. When the global 'explain' flag is set, a description is
       printed to stdout."""

    if explain:
        print("  ENTITY  id: #%d coords: %f, %f %s" % (id, x, y, explain_texts(texts)))

    # Fill in the geometry before appending the label, so a '%' in the
    # label text cannot be mistaken for a format directive.
    start_tag = ('<draw:rect draw:style-name="grhollowobject" draw:text-style-name="P1" draw:id="%d" draw:layer="layout" '
                 'svg:width="1.773cm" svg:height="0.761cm" svg:x="%fcm" svg:y="%fcm">') % \
                (id, x - 1.773 / 2, y - 0.761 / 2)
    return start_tag + ooo_text(texts) + "</draw:rect>"

def er_weak_entity( id, x, y, texts ):

    """Return OOo Draw XML for an ER weak entity symbol.

       The symbol is a path tracing an outer rectangle and an inner one
       (a double border) of 1.773 cm x 0.761 cm. It gets draw:id 'id' and
       is centred on (x, y). "texts" is a list of (text, style) tuples
       used as its label. When the global 'explain' flag is set, a
       description is printed to stdout."""

    if explain:
        print("  WEAK ENTITY  id: #%d coords: %f, %f %s" % (id, x, y, explain_texts(texts)))

    # Fill in the geometry before appending the label, so a '%' in the
    # label text cannot be mistaken for a format directive.
    start_tag = ('<draw:path draw:style-name="grhollowobject" draw:text-style-name="P1" draw:id="%d" '
                 'draw:layer="layout" svg:width="1.773cm" svg:height="0.761cm" svg:x="%fcm" svg:y="%fcm" '
                 'svg:viewBox="0 0 1773 761" svg:d="m0 760c0-253 0-506 0-760 590 0 1180 0 1772 0 0 253 0 506 '
                 '0 760-590 0-1180 0-1772 0zm72-64c0-210 0-420 0-632 542 0 1084 0 1628 0 0 210 0 420 0 632-542 '
                 '0-1084 0-1628 0z">') % \
                (id, x - 1.773 / 2, y - 0.761 / 2)
    return start_tag + ooo_text(texts) + '</draw:path>'

def er_attribute( id, x, y, texts ):

    """Return OOo Draw XML for an ER attribute symbol (an ellipse).

       The 1.52 cm x 0.507 cm ellipse gets draw:id 'id' and is centred on
       (x, y). Its label uses paragraph style P2. "texts" is a list of
       (text, style) tuples. When the global 'explain' flag is set, a
       description is printed to stdout."""

    if explain:
        print("  ATTRIBUTE  id: #%d coords: %f, %f %s" % (id, x, y, explain_texts(texts)))

    # Fill in the geometry before appending the label, so a '%' in the
    # label text cannot be mistaken for a format directive.
    start_tag = ('<draw:ellipse draw:style-name="grhollowobject" draw:text-style-name="P2" draw:id="%d" '
                 'draw:layer="layout" svg:width="1.52cm" svg:height="0.507cm" svg:x="%fcm" svg:y="%fcm">') % \
                (id, x - 1.52 / 2, y - 0.507 / 2)
    return start_tag + ooo_text(texts) + '</draw:ellipse>'

def build_er_diagram( parsed ):

    """Map the results from the SQL parser into an ER diagram.

       'parsed' maps table names to the table dicts built by the parser
       (keys "name", "primary", "cols" and "refs"). The table and column
       dicts are annotated in place with drawing data ("id", "x", "y",
       "weak", "refto", "parent", "name"). Returns the OpenOffice.org
       Draw XML for the whole diagram as a string. If the global
       'explain' flag is True, a textual list of the generated elements
       is printed to stdout as well.

       Every table becomes an entity, with its columns as attributes on a
       circle around it. The entity is weak if one of its primary-key
       columns is also a foreign key. A foreign-key column becomes a
       1-to-n relationship pointing at the referenced table. A weak table
       with exactly two primary-key columns and two foreign keys is drawn
       instead as an n-to-m relationship between the two tables it
       connects."""

    res = ""
    cur_id = 1
    cols_by_id = {}
    cr = 2.0      # radius of the attribute circle; grid spacing is 3 * cr
    sx = cr
    sy = cr
    x = sx

    # Assign a unique ID for each table and column:
    for tabname, tab in parsed.items():
      tab["id"] = cur_id
      cur_id += 1
      for k, c in tab["cols"].items():
        c["id"] = cur_id
        c["parent"] = tab
        c["name"] = k
        cols_by_id[cur_id] = c
        cur_id += 1

    # Place the tables on a grid and resolve their foreign keys
    for tabname, tab in parsed.items():
      y = sy
      tab["x"] = x
      tab["y"] = y
      tab["weak"] = False

      cols = tab["cols"]
      for cname in tab["primary"]:
        cols[cname]["primary"] = True

      # Spread the columns evenly on a circle around the table
      if cols:
        step = (2 * math.pi) / len(cols)
        a = 0.0
        for c in cols.values():
          c["x"] = x + math.cos(a) * cr
          c["y"] = y + math.sin(a) * cr
          a += step

      # Store references to columns; only the first column of a
      # composite foreign key is taken into account.
      for c in cols.values():
        c["refto"] = -1
      for fk in tab["refs"]:
        fromcol = cols[fk["fromcols"][0]]
        if fromcol["primary"]:
          tab["weak"] = True
        tocol = parsed[fk["totable"]]["cols"][fk["tocols"][0]]
        fromcol["refto"] = tocol["id"]

      # Wrap to the next line, if necessary
      x += cr * 3
      if x > 40:
        x = sx
        sy += cr * 3

    # Tables that merely connect two other tables are drawn as n-to-m
    # relationships midway between those tables. Move them before any
    # drawing happens, and shift their attribute positions by the same
    # amount so the attributes stay around the relationship symbol.
    junctions = {}
    for tabname, tab in parsed.items():
      if tab["weak"] and len(tab["primary"]) == 2 and len(tab["refs"]) == 2:
        junctions[tabname] = True
        (fk_a, fk_b) = tab["refs"]
        tab_a = parsed[fk_a["totable"]]
        tab_b = parsed[fk_b["totable"]]
        new_x = (tab_a["x"] + tab_b["x"]) / 2
        new_y = (tab_a["y"] + tab_b["y"]) / 2
        for c in tab["cols"].values():
          c["x"] += new_x - tab["x"]
          c["y"] += new_y - tab["y"]
        tab["x"] = new_x
        tab["y"] = new_y

    # Draw the actual diagram
    for tabname, tab in parsed.items():
      if explain:
        print("TABLE '%s'" % tabname)

      group_res = ""
      free_res = ""

      if tabname in junctions:

        # n-to-m relationship (possibly with extra attributes) between
        # the two referenced tables:
        (fk_a, fk_b) = tab["refs"]
        group_res += er_relationship( tab["id"], tab["x"], tab["y"], [(tab["name"], "COLNAME")])
        free_res += connector_plain( tab["id"], parsed[fk_a["totable"]]["id"] )
        free_res += connector_plain( tab["id"], parsed[fk_b["totable"]]["id"] )

        # Only non-key columns are drawn as attributes, so they always
        # use the plain column style.
        for k, c in tab["cols"].items():
          if not c["primary"]:
            group_res += er_attribute( c["id"], c["x"], c["y"], [(k, "COLNAME")])
            group_res += connector_plain( tab["id"], c["id"] )

      else:

        # Table is either a normal entity or a larger connecting set, represent it
        # as an entity (either weak or strong):
        if tab["weak"]:
          ent_func = er_weak_entity
        else:
          ent_func = er_entity
        group_res += ent_func( tab["id"], tab["x"], tab["y"], [(tab["name"], "COLNAME")])

        for k, c in tab["cols"].items():
          if c["primary"]:
            text_style = "KEYCOLNAME"
          else:
            text_style = "COLNAME"
          label = [(k, text_style)]

          if c["refto"] > -1:

            # Interpret columns with a foreign key as 1-to-n relationships
            tocol = cols_by_id[c["refto"]]
            totab = tocol["parent"]
            if c["primary"]:
              rel_func = er_weak_relationship
            else:
              rel_func = er_relationship

            # If the name of the column is same in both ends of the foreign key
            # reference, use "<target-table>-of" as a name for the relationship
            if k == tocol["name"]:
              label = [(totab["name"] + "-of", "COLNAME")]

            free_res += rel_func( c["id"], (c["x"] + totab["x"]) / 2,
                                  (c["y"] + totab["y"]) / 2, label)
            free_res += connector_plain( tab["id"], c["id"] )
            free_res += connector_filled_arrow( c["id"], totab["id"] )

          else:
            # Normal attribute
            group_res += er_attribute( c["id"], c["x"], c["y"], label)
            group_res += connector_plain( tab["id"], c["id"] )

      if group_res:
        res += "<draw:g>" + group_res + "</draw:g>"
      res += free_res

    return res


def build_table_diagram( parsed ):

    """Map the results from the SQL parser into a table structure diagram.

       'parsed' maps table names to the table dicts built by the parser
       (keys "name", "primary", "cols" and "refs"). Column dicts are
       annotated in place with "id", "x", "y" and "primary". Each table is
       drawn as a group of text boxes: a header box with the table name,
       then one box per column with its name and type. Each foreign key
       gets a filled arrow from the referencing column to the referenced
       one. Returns the OOo Draw XML as a string. If the global 'explain'
       flag is True, a textual list of the generated elements is printed
       to stdout as well."""

    res = ""
    cur_id = 1
    sx = 1.0
    sy = 1.0
    cw = 3.5    # text box width (cm)
    ch = 0.5    # text box height (cm)
    x = sx

    # Lay the tables out left to right, starting a new row once x passes
    # 19 cm; each row is as tall as its tallest table.
    max_h = 0
    for tabname, tab in parsed.items():
      y = sy
      tab["x"] = x
      tab["y"] = y
      y += ch
      h = ch
      for cname in tab["primary"]:
        tab["cols"][cname]["primary"] = True
      for c in tab["cols"].values():
        c["id"] = cur_id
        c["x"] = x
        c["y"] = y
        cur_id += 1
        y += ch
        h += ch
      if h > max_h:
        max_h = h
      x += cw * 1.5
      if x > 19:
        x = sx
        sy += max_h + ch*2
        max_h = 0

    # Header boxes take the IDs that follow the column IDs.
    for tabname, tab in parsed.items():
      if explain:
        print("TABLE '%s'" % tabname)

      res += "<draw:g>"
      res += text_box( cur_id, tab["x"], tab["y"], cw, ch, [(tab["name"], "TABNAME")])
      cur_id += 1
      for k, c in tab["cols"].items():
        if c["primary"]:
          style = "KEYCOLNAME"
        else:
          style = "COLNAME"
        res += text_box( c["id"], c["x"], c["y"], cw, ch, [(k, style), (" (" + c["type"] + ")", "TYPENAME")])
      res += "</draw:g>"

      for fk in tab["refs"]:
        # Only the first column of a composite foreign key is drawn;
        # tell the user about the ones that are left out.
        if len(fk["fromcols"]) > 1:
          sys.stderr.write("Warning: table '%s': only the first column of foreign key (%s) is drawn.\n"
                           % (tabname, ", ".join(fk["fromcols"])))
        fromcol = tab["cols"][fk["fromcols"][0]]
        tocol = parsed[fk["totable"]]["cols"][fk["tocols"][0]]
        # Arrow from the right edge of the source box to the left edge
        # of the target box, both at mid-height.
        res += arrow_filled( fromcol["id"], fromcol["x"] + cw, fromcol["y"] + ch/2,
                             tocol["id"], tocol["x"], tocol["y"] + ch/2 )

    return res


if __name__=='__main__':

    # Parse options & arguments

    try:
      opts, args = getopt.getopt(sys.argv[1:], "hesto:",
        ["help","er", "struct", "text", "output="])
    except getopt.GetoptError, e:
      sys.stderr.write( "Error: " + str(e) + "\n" )
      sys.stderr.write( usagetext() + "\n" )
      sys.exit(2)

    outfile = False
    format_func = build_table_diagram

    for o, a in opts:
      if o in ("-h", "--help"):
        print usagetext()
        sys.exit()
      elif o in ("-t", "--text"):
        explain = True
      elif o in ("-o", "--output"):
        outfile = a
      elif o in ("-e", "--er"):
        format_func = build_er_diagram
      elif o in ("-s", "--struct"):
        format_func = build_table_diagram

    if explain and not outfile:
      sys.stderr.write("ERROR: You must also specify -o when using -t.\n")
      sys.exit(2)

    # Read the SQL statements from STDIN,
    # parse and generate the drawing

    while 1:
        try: s = sys.stdin.read()
     except EOFError: break
        if not string.strip(s): break

        parsed = parse('goal', s)
        content_str = format_func( parsed )

        f = open("sxd-template.sxd", "r")
        template = f.read()
        f.close()

        outstr = merge_ooo_doc(template, [("$$CONTENT$$", "CONTENT")], [{"CONTENT":content_str}])
        if outfile:
          f = open(outfile, "w")
          f.write( outstr )
          f.close()
        else:
          sys.stdout.write( outstr )