-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathcli.go
More file actions
132 lines (116 loc) · 2.73 KB
/
cli.go
File metadata and controls
132 lines (116 loc) · 2.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
package main
import (
"flag"
"fmt"
"io/ioutil"
"log"
"os/exec"
"strings"
"github.com/cloudspannerecosystem/memefish"
"github.com/cloudspannerecosystem/memefish/ast"
"github.com/cloudspannerecosystem/memefish/token"
)
// defaultFileBasename is the base name of the output file used by cli.run
// when no -o value is supplied.
const defaultFileBasename = "spanner_er"

// Exit codes returned by cli.run. The numeric values are written out
// explicitly so that adding or reordering declarations cannot change them.
const (
	exitCodeOK        int = 0
	exitCodeError         = 12
	exitCodeArgsError     = 13
)
// cli is the receiver type for the command-line entry point; it has no fields.
type cli struct{}
// run parses the command-line arguments, reads the schema file, builds a
// graph from its CREATE TABLE statements and renders it by piping the graph's
// string form into the external `dot` command.
//
// It returns exitCodeOK on success or when -h was given, exitCodeArgsError
// when flag parsing fails or no schema file (-s) was supplied, and
// exitCodeError when reading, parsing, graph construction or the `dot`
// invocation fails.
func (cli *cli) run(args []string) int {
	var (
		help   bool
		file   string
		output string
		t      string
	)
	flags := flag.NewFlagSet("", flag.ContinueOnError)
	flags.BoolVar(&help, "h", false, "print help")
	flags.StringVar(&file, "s", "", "spanner schema file")
	flags.StringVar(&output, "o", "", "output file name.default is spanner_er.<type>(pass to dot option -o)")
	flags.StringVar(&t, "T", "png", "output file type. default is png(pass to dot option -T)")
	if err := flags.Parse(args); err != nil {
		return exitCodeArgsError
	}
	if help {
		flags.Usage()
		return exitCodeOK
	}
	if file == "" {
		flags.Usage()
		return exitCodeArgsError
	}
	if output == "" {
		output = fmt.Sprintf("%s.%s", defaultFileBasename, t)
	}

	body, err := cli.read(file)
	if err != nil {
		log.Print(err)
		return exitCodeError
	}
	tables, err := parse(body)
	if err != nil {
		log.Print(err)
		return exitCodeError
	}
	graph, err := NewGraph()
	if err != nil {
		log.Print(err)
		return exitCodeError
	}
	if err := graph.ApplyTables(tables); err != nil {
		log.Print(err)
		return exitCodeError
	}

	// dot receives the graph description on stdin and writes the rendered
	// file itself via -o.
	cmd := exec.Command("dot", fmt.Sprintf("-T%s", t), "-o", output)
	cmd.Stdin = strings.NewReader(graph.String())
	// CombinedOutput starts the command and waits for it. Its error covers a
	// missing executable as well as a non-zero exit status, and the captured
	// output carries whatever diagnostics dot printed.
	if out, err := cmd.CombinedOutput(); err != nil {
		if msg := strings.TrimSpace(string(out)); msg != "" {
			log.Printf("dot: %v: %s", err, msg)
		} else {
			log.Printf("dot: %v", err)
		}
		return exitCodeError
	}
	return exitCodeOK
}
// read returns the full contents of the named file as a string. On failure
// it returns an empty string together with the error from the read.
func (cli *cli) read(file string) (string, error) {
	content, err := ioutil.ReadFile(file)
	if err != nil {
		return "", err
	}
	return string(content), nil
}
// parse splits sqls on semicolons, parses every non-empty piece as a single
// statement and returns the CREATE TABLE statements in input order.
//
// A piece that fails to parse is skipped so that one unparseable statement
// does not abort the whole run, but each skip is logged rather than dropped
// silently. Statements that parse to anything other than CREATE TABLE are
// ignored. The returned error is always nil.
//
// NOTE(review): splitting on ";" does not account for semicolons that appear
// inside string literals or comments; such input is cut mid-statement and the
// resulting pieces will most likely be reported as skipped.
func parse(sqls string) ([]*ast.CreateTable, error) {
	var tables []*ast.CreateTable
	for i, stmt := range strings.Split(sqls, ";") {
		stmt = strings.TrimSpace(stmt)
		if stmt == "" {
			continue
		}

		// Each statement gets its own parser over a buffer holding only that
		// statement.
		p := &memefish.Parser{
			Lexer: &memefish.Lexer{
				File: &token.File{Buffer: stmt},
			},
		}
		parsed, err := p.ParseStatement()
		if err != nil {
			// Best effort: keep going, but tell the user what was dropped.
			log.Printf("skipping statement #%d: %v", i+1, err)
			continue
		}

		if createTable, ok := parsed.(*ast.CreateTable); ok {
			tables = append(tables, createTable)
		}
	}
	return tables, nil
}
// getTableName returns the Name of the last identifier in t.Name.Idents.
// It returns "" when t.Name is nil or holds no identifiers.
func getTableName(t *ast.CreateTable) string {
	if t.Name == nil {
		return ""
	}
	idents := t.Name.Idents
	if len(idents) == 0 {
		return ""
	}
	return idents[len(idents)-1].Name
}