Files
musenalm/helpers/imports/data.go
2026-01-28 19:37:19 +01:00

253 lines
5.1 KiB
Go

package imports
import (
"archive/zip"
"encoding/xml"
"fmt"
"io"
"os"
"path/filepath"
"sort"
"strings"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core"
)
type (
	// tableFile describes one table export (an XML file) belonging to an
	// import candidate.
	tableFile struct {
		Name string    // table name: the file's base name without its ".xml" extension
		Path string    // filesystem path associated with the entry
		Zip  *zip.File // archive member holding the XML; nil for plain files on disk
	}

	// columnValue holds one column of a <row> element read from an export.
	columnValue struct {
		Value  string // text content of the column element
		IsNull bool   // when true the column is written as SQL NULL instead of Value
	}
)
// ImportData replaces the contents of the tables exported in candidate with
// the rows found in its XML files.
//
// Tables rejected by shouldSkipImportTable are left untouched. All other
// tables are first emptied in reverse file-name order and then refilled
// from their XML files in file-name order. The reverse order presumably
// exists so that dependent tables are cleared before the tables they
// reference (TODO confirm).
//
// The work runs inside app.RunInTransaction. Returning an error from the
// callback is expected to roll back the whole import.
//
// It returns an error if candidate is nil, if it contains no XML files, or
// if clearing or importing any table fails. Table-level failures are wrapped
// with the table name.
func ImportData(app core.App, candidate *ImportCandidate) error {
	if candidate == nil {
		return fmt.Errorf("no data import candidate found")
	}
	files, err := listDataFiles(candidate)
	if err != nil {
		return err
	}
	if len(files) == 0 {
		return fmt.Errorf("no XML files found in data import")
	}

	return app.RunInTransaction(func(txApp core.App) error {
		for i := len(files) - 1; i >= 0; i-- {
			name := files[i].Name
			if shouldSkipImportTable(name) {
				continue
			}
			if _, err := txApp.DB().NewQuery("DELETE FROM " + quoteTableName(name)).Execute(); err != nil {
				return fmt.Errorf("clearing table %q: %w", name, err)
			}
		}
		for _, file := range files {
			if shouldSkipImportTable(file.Name) {
				continue
			}
			if err := importSingleTableFile(txApp, file); err != nil {
				return fmt.Errorf("importing table %q: %w", file.Name, err)
			}
		}
		return nil
	})
}

// importSingleTableFile opens one table file, streams it through
// importTableXML and always releases the reader, whatever the outcome.
func importSingleTableFile(app core.App, file tableFile) error {
	reader, closeFn, err := openTableReader(file)
	if err != nil {
		return err
	}
	defer closeFn()
	return importTableXML(app, reader)
}
// listDataFiles returns the XML table files of an import candidate, sorted
// by table name. Entries with the same name keep their listing order.
//
// For a ZIP candidate it considers every non-directory archive member. For
// a directory candidate it considers the directory's direct, non-directory
// entries. A file qualifies when its base name ends in ".xml",
// case-insensitively, and the table name is that base name with the
// extension removed.
//
// The archive is closed before this function returns. Reading through a
// *zip.File obtained from a closed archive fails, so ZIP entries also record
// the archive path in Path, and openTableReader reopens the archive from
// there.
func listDataFiles(candidate *ImportCandidate) ([]tableFile, error) {
	if candidate == nil {
		return nil, fmt.Errorf("no data import candidate found")
	}
	files := []tableFile{}
	if candidate.IsZip {
		archive, err := zip.OpenReader(candidate.Path)
		if err != nil {
			return nil, err
		}
		defer archive.Close()
		for _, member := range archive.File {
			if member.FileInfo().IsDir() {
				continue
			}
			name, ok := trimXMLExtension(filepath.Base(member.Name))
			if !ok {
				continue
			}
			files = append(files, tableFile{
				Name: name,
				Path: candidate.Path, // archive path, used to reopen the member later
				Zip:  member,
			})
		}
	} else {
		entries, err := os.ReadDir(candidate.Path)
		if err != nil {
			return nil, err
		}
		for _, entry := range entries {
			if entry.IsDir() {
				continue
			}
			name, ok := trimXMLExtension(entry.Name())
			if !ok {
				continue
			}
			files = append(files, tableFile{
				Name: name,
				Path: filepath.Join(candidate.Path, entry.Name()),
			})
		}
	}
	sort.SliceStable(files, func(i, j int) bool {
		return files[i].Name < files[j].Name
	})
	return files, nil
}

// trimXMLExtension reports whether name ends in ".xml" (case-insensitively)
// and, if so, returns name with that extension removed.
func trimXMLExtension(name string) (string, bool) {
	const ext = ".xml"
	if len(name) < len(ext) || !strings.EqualFold(name[len(name)-len(ext):], ext) {
		return "", false
	}
	return name[:len(name)-len(ext)], true
}

// openTableReader opens the XML content of a table file.
//
// The returned close function must be called once the caller has finished
// reading. When an error is returned, the reader is nil and the close
// function is a no-op.
//
// For ZIP entries that carry the archive path in Path, the archive is
// reopened here, and the returned reader and close function release both the
// member and the archive. A ZIP entry without a Path is opened directly. This
// assumes the caller still holds its archive open.
func openTableReader(file tableFile) (io.ReadCloser, func(), error) {
	noop := func() {}
	switch {
	case file.Zip == nil:
		f, err := os.Open(file.Path)
		if err != nil {
			return nil, noop, err
		}
		return f, func() { _ = f.Close() }, nil // read-only: close errors carry no data loss
	case file.Path == "":
		rc, err := file.Zip.Open()
		if err != nil {
			return nil, noop, err
		}
		return rc, func() { _ = rc.Close() }, nil
	}

	archive, err := zip.OpenReader(file.Path)
	if err != nil {
		return nil, noop, err
	}
	member := findZipMember(&archive.Reader, file.Zip)
	if member == nil {
		_ = archive.Close()
		return nil, noop, fmt.Errorf("zip entry %q not found in %s", file.Zip.Name, file.Path)
	}
	rc, err := member.Open()
	if err != nil {
		_ = archive.Close()
		return nil, noop, err
	}
	entry := &zipEntryReader{ReadCloser: rc, archive: archive}
	return entry, func() { _ = entry.Close() }, nil
}

// findZipMember returns the member of archive matching want's header
// (name, CRC-32 and uncompressed size), or nil if there is none.
func findZipMember(archive *zip.Reader, want *zip.File) *zip.File {
	for _, f := range archive.File {
		if f.Name == want.Name && f.CRC32 == want.CRC32 && f.UncompressedSize64 == want.UncompressedSize64 {
			return f
		}
	}
	return nil
}

// zipEntryReader reads one archive member and owns the archive it was opened
// from, so that Close releases both.
type zipEntryReader struct {
	io.ReadCloser
	archive *zip.ReadCloser
}

// Close closes the member reader and then the archive. It returns the first
// error encountered.
func (r *zipEntryReader) Close() error {
	err := r.ReadCloser.Close()
	if archiveErr := r.archive.Close(); err == nil {
		err = archiveErr
	}
	return err
}
// importTableXML streams one table export from reader and inserts each row
// it contains via insertRow.
//
// The layout accepted by the parser below is:
//
//	<table name="users">
//	  <row>
//	    <id>1</id>
//	    <email null="true"/>
//	  </row>
//	</table>
//
// Each element directly inside <row> is one column. The element's local name
// is the column name and its text content is the value. A null="true"
// attribute (see hasNullAttr) marks the value as SQL NULL. Inside a row every
// element is a column, even one literally named "row" or "table". Elements
// outside <table>/<row> (for example a wrapping root element) are ignored.
//
// It returns an error in these cases:
//   - the XML is malformed;
//   - a <table> element has no non-blank name attribute;
//   - a <row> appears before any named <table>;
//   - an insert fails (the error is wrapped with the row number and table).
func importTableXML(app core.App, reader io.Reader) error {
	decoder := xml.NewDecoder(reader)
	var tableName string
	// rowValues is non-nil exactly while the decoder is inside a <row>.
	var rowValues map[string]columnValue
	rowCount := 0
	for {
		token, err := decoder.Token()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return err
		}
		switch t := token.(type) {
		case xml.StartElement:
			switch {
			case rowValues != nil:
				// Checked first, so a column named "row" or "table" is not
				// mistaken for structure.
				isNull := hasNullAttr(t.Attr)
				var text string
				if err := decoder.DecodeElement(&text, &t); err != nil {
					return err
				}
				rowValues[t.Name.Local] = columnValue{Value: text, IsNull: isNull}
			case t.Name.Local == "table":
				tableName = ""
				for _, attr := range t.Attr {
					if attr.Name.Local == "name" {
						tableName = strings.TrimSpace(attr.Value)
						break
					}
				}
				if tableName == "" {
					return fmt.Errorf("missing table name in XML")
				}
			case t.Name.Local == "row":
				if tableName == "" {
					return fmt.Errorf("row found before a named table element in XML")
				}
				rowValues = map[string]columnValue{}
			}
		case xml.EndElement:
			// Column end tags are consumed by DecodeElement, so the only end
			// tag seen while inside a row is the row's own.
			if rowValues != nil && t.Name.Local == "row" {
				rowCount++
				if err := insertRow(app, tableName, rowValues); err != nil {
					return fmt.Errorf("inserting row %d into %q: %w", rowCount, tableName, err)
				}
				rowValues = nil
			}
		}
	}
}
// insertRow inserts a single row into tableName.
//
// row maps column names to values. Columns are emitted in sorted name order,
// so the generated SQL text is identical for rows with the same columns.
// Values are always bound as parameters, never interpolated. Columns marked
// IsNull are bound as nil. Table and column names are quoted with
// quoteTableName / quoteIdentifier.
//
// An empty row is skipped and nil is returned. Otherwise the error from
// executing the INSERT is returned.
func insertRow(app core.App, tableName string, row map[string]columnValue) error {
	if len(row) == 0 {
		return nil
	}

	names := make([]string, 0, len(row))
	for name := range row {
		names = append(names, name)
	}
	sort.Strings(names)

	cols := make([]string, len(names))
	placeholders := make([]string, len(names))
	params := make(dbx.Params, len(names))
	for i, name := range names {
		paramName := fmt.Sprintf("p%d", i)
		cols[i] = quoteIdentifier(name)
		placeholders[i] = "{:" + paramName + "}"
		if value := row[name]; value.IsNull {
			params[paramName] = nil
		} else {
			params[paramName] = value.Value
		}
	}

	query := "INSERT INTO " + quoteTableName(tableName) +
		" (" + strings.Join(cols, ", ") + ") VALUES (" + strings.Join(placeholders, ", ") + ")"
	_, err := app.DB().NewQuery(query).Bind(params).Execute()
	return err
}
// hasNullAttr reports whether any attribute in attrs has the local name
// "null" and a value equal to "true", ignoring case. The attribute's
// namespace is not considered, and surrounding whitespace in the value is
// not trimmed.
func hasNullAttr(attrs []xml.Attr) bool {
	for i := range attrs {
		if attrs[i].Name.Local != "null" {
			continue
		}
		if strings.EqualFold(attrs[i].Value, "true") {
			return true
		}
	}
	return false
}
// shouldSkipImportTable reports whether the data import must leave the
// table called name alone. It skips three kinds of name:
//   - the empty name;
//   - any name starting with an underscore (presumably PocketBase's internal
//     tables; confirm);
//   - "schema_migrations" exactly, case-sensitive.
func shouldSkipImportTable(name string) bool {
	switch {
	case name == "", strings.HasPrefix(name, "_"), name == "schema_migrations":
		return true
	default:
		return false
	}
}
// quoteTableName quotes a table name for SQL text. It applies the same
// transformation as quoteIdentifier; the separate name only documents intent
// at call sites.
func quoteTableName(name string) string {
	return quoteIdentifier(name)
}

// quoteIdentifier wraps name in backticks and doubles any embedded backtick,
// so the result is a single quoted identifier token, whatever name contains.
// NOTE(review): backtick quoting is accepted by SQLite (and MySQL); confirm
// this matches the database PocketBase is configured with.
func quoteIdentifier(name string) string {
	return "`" + strings.ReplaceAll(name, "`", "``") + "`"
}