diff --git a/src/lib.rs b/src/lib.rs index a3ad96f..3d794f0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -660,6 +660,46 @@ mod tests { assert_eq!(format(input, &QueryParams::None, &options), expected); } + #[test] + fn it_does_format_drop() { + let input = indoc!( + " + DROP INDEX IF EXISTS idx_a; + DROP INDEX IF EXISTS idx_b; + " + ); + + let options = FormatOptions { + ..Default::default() + }; + + let expected = indoc!( + " + DROP INDEX IF EXISTS + idx_a; + DROP INDEX IF EXISTS + idx_b;" + ); + + assert_eq!(format(input, &QueryParams::None, &options), expected); + + let input = indoc!( + r#" + -- comment + DROP TABLE IF EXISTS "public"."table_name"; + "# + ); + + let expected = indoc!( + r#" + -- comment + DROP TABLE IF EXISTS + "public"."table_name";"# + ); + + assert_eq!(format(input, &QueryParams::None, &options), expected); + } + #[test] fn it_formats_select_query_with_inner_join() { let input = indoc!( @@ -995,8 +1035,12 @@ mod tests { fn it_formats_simple_drop_query() { let input = "DROP TABLE IF EXISTS admin_role;"; let options = FormatOptions::default(); - - assert_eq!(format(input, &QueryParams::None, &options), input); + let output = indoc!( + " + DROP TABLE IF EXISTS + admin_role;" + ); + assert_eq!(format(input, &QueryParams::None, &options), output); } #[test] @@ -1418,8 +1462,30 @@ mod tests { " ALTER TABLE supplier - ALTER COLUMN - supplier_name VARCHAR(100) NOT NULL;" + ALTER COLUMN supplier_name VARCHAR(100) NOT NULL;" + ); + + assert_eq!(format(input, &QueryParams::None, &options), expected); + } + + #[test] + fn it_formats_alter_table_add_and_drop() { + let input = r#"ALTER TABLE "public"."event" DROP CONSTRAINT "validate_date", ADD CONSTRAINT "validate_date" CHECK (end_date IS NULL + OR (start_date IS NOT NULL AND end_date > start_date));"#; + + let options = FormatOptions::default(); + let expected = indoc!( + r#" + ALTER TABLE + "public"."event" + DROP CONSTRAINT "validate_date", + ADD CONSTRAINT "validate_date" CHECK ( + end_date IS NULL + OR ( + start_date IS NOT NULL + AND end_date > start_date + ) + );"# ); assert_eq!(format(input, &QueryParams::None, &options), expected); @@ -2196,8 +2262,7 @@ mod tests { -- 自动加载数据到 Hive 分区中 ALTER TABLE sales_data - ADD - PARTITION (sale_year = '2024', sale_month = '08') LOCATION '/user/hive/warehouse/sales_data/2024/08';" + ADD PARTITION (sale_year = '2024', sale_month = '08') LOCATION '/user/hive/warehouse/sales_data/2024/08';" ); assert_eq!(format(input, &QueryParams::None, &options), expected); diff --git a/src/tokenizer.rs b/src/tokenizer.rs index 065c20d..3ddaa6f 100644 --- a/src/tokenizer.rs +++ b/src/tokenizer.rs @@ -488,35 +488,60 @@ fn get_top_level_reserved_token<'a>( // First peek at the first character to determine which group to check let first_char = peek(any).parse_next(input)?.to_ascii_uppercase(); + let create_or_replace = ( + "AGGREGATE", + "FUNCTION", + "LANGUAGE", + "PROCEDURE", + "RULE", + "TRIGGER", + "VIEW", + ); + + let alterable_or_droppable = alt((alt(create_or_replace), "TABLE", "INDEX")); + let create_or_replace = alt(create_or_replace); + // Match keywords based on their first letter let result: Result<&str> = match first_char { 'A' => alt(( - terminated("ADD", end_of_word), terminated("AFTER", end_of_word), - terminated("ALTER COLUMN", end_of_word), - terminated("ALTER TABLE", end_of_word), + terminated(("ALTER ", alterable_or_droppable).take(), end_of_word), )) .parse_next(&mut uc_input), 'C' => terminated( ( "CREATE ", - opt(alt(( - "UNLOGGED ", + alt(( + create_or_replace, + (opt("UNIQUE "), "INDEX").take(), ( - alt(("GLOBAL ", "LOCAL ")), - opt(alt(("TEMPORARY ", "TEMP "))), + opt(alt(( + "UNLOGGED ", + ( + alt(("GLOBAL ", "LOCAL ")), + opt(alt(("TEMPORARY ", "TEMP "))), + ) + .take(), + ))), + "TABLE", ) .take(), - ))), - "TABLE", + )), ) .take(), end_of_word, ) .parse_next(&mut uc_input), - 'D' => terminated("DELETE FROM", end_of_word).parse_next(&mut uc_input), + 'D' => alt(( + terminated("DELETE FROM", end_of_word), + terminated( + ("DROP ", alterable_or_droppable, opt(" IF EXISTS")).take(), + end_of_word, + ), + )) + .parse_next(&mut uc_input), 'E' => terminated("EXCEPT", end_of_word).parse_next(&mut uc_input),